xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision 2151003e773c7e7dba4d64bed4bfc483681b5f6a)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17
18
def c_upper(name):
    """Return *name* as an upper-case C identifier (dashes become underscores)."""
    return '_'.join(name.upper().split('-'))
21
22
def c_lower(name):
    """Return *name* as a lower-case C identifier (dashes become underscores)."""
    return '_'.join(name.lower().split('-'))
25
26
def limit_to_number(name):
    """
    Convert a symbolic limit such as ``u32-max`` or ``s64-min`` to an int.

    The first character selects signedness ('u' or 's'), the digits give the
    bit width and the four-character suffix is ``-min`` or ``-max``.
    """
    kind = name[0]
    wants_min = name.endswith('-min')
    if kind == 'u' and wants_min:
        return 0

    # Signed types lose one bit of magnitude to the sign
    bits = int(name[1:-4]) - (1 if kind == 's' else 0)
    magnitude = (1 << bits) - 1
    if kind == 's' and wants_min:
        return -magnitude - 1
    return magnitude
40
41
class BaseNlLib:
    """Default library hooks consulted by the C code generator."""

    def get_family_id(self):
        """Return, as C source text, the expression naming the family ID."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
45
46
class Type(SpecAttr):
    """Base code-generation wrapper around one attribute of an attribute set.

    The methods render C fragments for the attribute: kernel policy entries,
    user-space parse policy ("typol") entries, struct members, put/get code
    and setter helpers. Subclasses override the type-specific pieces; the
    defaults here either cover the common case or raise.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        # NOTE: when the spec has a 'checks' mapping this is that very dict,
        # so subclasses adding keys to it also modify the spec data.
        self.checks = attr.get('checks', {})

        # Flipped by set_request() / set_reply()
        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # A trailing '_' is added when family.consts has an entry with the
            # same name (presumably to avoid clashing with its C type).
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Names listed in _C_KW (presumably C keywords) get a '_' suffix
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attribute as used in requests; subsets also mark the real attr."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attribute as used in replies; subsets also mark the real attr."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """Return the numeric value of check *limit*.

        The check may be an int, the name of a family constant, or a symbolic
        limit such as ``u32-max``. When the check is absent, *default* is
        used in its place; ``None`` is returned as-is.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """Return check *limit* rendered as C source text ('' when absent).

        Ints are printed literally (with *suffix*); constants and symbolic
        limits are rendered as upper-case C names.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute the attribute's C enum name and validate subset checks."""
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Return how presence is tracked: 'bit', 'len', 'count' or ''."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the ``_present`` member declaration, or None.

        None is returned unless presence_type() equals *type_filter* and is
        'bit' or 'len'. User-space declarations use ``__u32``, other spaces
        use ``u32``.
        """
        if self.presence_type() != type_filter:
            return

        if self.presence_type() == 'bit':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() == 'len':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}_len;"

    def _complex_member_type(self, ri):
        return None

    def free_needs_iter(self):
        return False

    def free(self, ri, var, ref):
        """Emit the free() of heap storage for multi-value / length-tracked members."""
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the C parameter declarations the setter takes for this attr."""
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit this attribute's member(s) of the generated C struct."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            # Recursive nests must be held by pointer
            if self.is_recursive_for_op(ri):
                ptr = '*'
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attribute's entry of the kernel netlink policy table."""
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attribute's entry of the user-space parse policy table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        # Guard the emitted statement with the matching presence check
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the ``if``/``else if (type == ...)`` branch parsing this attr.

        _attr_get() supplies (lines, init_lines, local_vars); either line
        set may be a single string. Returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        # Multi-value attrs are only counted here, so skip validation/presence
        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a ``static inline`` setter for this attribute.

        *ref* lists the c_names of enclosing nests; the setter also sets
        their presence bits.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        for i, part in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{part}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == len(ref) - 1 and self.presence_type() != 'bit':
                continue
            code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        # Setters that free the old value but never allocate get a "__" prefix
        has_free = any('free(' in line for line in code)
        has_alloc = any('alloc(' in line for line in code)
        if has_free and not has_alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
282
283
class TypeUnused(Type):
    """Attribute ID marked 'unused' in the spec.

    The only rendered artifact is a reject entry in the parse policy table;
    policy, put, get and setter generation are all no-ops.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        # NOTE(review): unreachable while attr_get() below is a no-op
        return (['return YNL_PARSE_CB_ERROR;'], None, None)

    def attr_get(self, ri, var, first):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_policy(self, cw):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
308
309
class TypePad(Type):
    """Padding attribute: ignored when parsing, never put, stored or set."""

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    # The remaining generators intentionally render nothing.
    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
331
332
class TypeScalar(Type):
    """Integer attribute (one of the types in ``scalars``).

    Derives kernel policy checks (mask, range, min, max) from the spec's
    explicit 'checks' and, for enum-typed attributes, from the enum itself.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Appended to the setter argument declaration (see arg_member)
        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Enum-typed attributes default min/max to the enum's value range;
        # EnumSet.value_range() raises for noncontiguous enums.
        # NOTE: self.checks may be the spec's own 'checks' dict, so the keys
        # added below are written back into the spec data.
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if 'min' not in self.checks:
                # Skip a min of 0 for unsigned types
                if low != 0 or self.type[0] == 's':
                    self.checks['min'] = low
            if 'max' not in self.checks:
                self.checks['max'] = high

        # Both bounds present -> rendered as a range policy
        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        # Missing bounds are treated as 0 for the checks below
        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside the signed 16-bit range switch to the
        # NLA_POLICY_FULL_RANGE form, which points at a separate *_range object
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        self.resolve_up(super())

        # Bitfields (flags) are validated with a mask instead of a range
        if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']:
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        # C type for the setter argument: the enum's user type, a 64-bit
        # type when is_auto_scalar is set, otherwise '__<type>'
        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _attr_policy(self, policy):
        # Priority: mask, full range, range, min only, max only, plain type
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
                mask = enum.get_mask(as_flags=True)
            else:
                # Mask covering as many low bits as the flags set has entries
                flags = self.family.consts[self.checks['flags-mask']]
                flag_cnt = len(flags['entries'])
                mask = (1 << flag_cnt) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed types also map to YNL_PT_U<width> (type[1:] drops the sign letter)
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
419
420
class TypeFlag(Type):
    """Attribute with no payload; only its presence is recorded."""

    def arg_member(self, ri):
        # The setter takes no value argument
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # The presence bit is set by the generic attr_get(); nothing to copy
        return ([], None, None)

    def _setter_lines(self, ri, member, presence):
        # Setting the presence bit is all the setter has to do
        return list()
436
437
class TypeString(Type):
    """String attribute, stored as a malloc()ed, NUL-terminated ``char *``.

    Presence is tracked through the stored length (``_present.<name>_len``).
    """

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        # Plain literal: nothing to interpolate
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        mem = '{ .type = ' + policy
        if 'max-len' in self.checks:
            mem += ', .len = ' + self.get_limit_str('max-len')
        return mem + ', }'

    def attr_policy(self, cw):
        """Emit the kernel policy entry; 'unterminated-ok' selects NLA_STRING."""
        if self.checks.get('unterminated-ok', False):
            policy = 'NLA_STRING'
        else:
            policy = 'NLA_NUL_STRING'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        # Copy at most the attribute payload and always NUL-terminate the copy
        len_mem = var + '->_present.' + self.c_name + '_len'
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len + 1);",
                f"memcpy({var}->{self.c_name}, ynl_attr_get_str(attr), len);",
                f"{var}->{self.c_name}[len] = 0;"], \
               ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        # Replace any previous value with a private NUL-terminated copy
        return [f"free({member});",
                f"{presence}_len = strlen({self.c_name});",
                f"{member} = malloc({presence}_len + 1);",
                f'memcpy({member}, {self.c_name}, {presence}_len);',
                f'{member}[{presence}_len] = 0;']
488
489
class TypeBinary(Type):
    """Opaque binary attribute, stored as a malloc()ed ``void *`` plus length."""

    # Checks the kernel policy rendering understands for binary attributes
    _SUPPORTED_CHECKS = ('exact-len', 'min-len', 'max-len')

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        # NOTE: no trailing space, unlike the other typols; kept so the
        # generated output does not change.
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        """Return the kernel policy for at most one length check.

        Raises:
            Exception: more than one check, or an unsupported check.
        """
        if len(self.checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        for check_name in self.checks:
            if check_name not in self._SUPPORTED_CHECKS:
                raise Exception('Unsupported check for binary type: ' + check_name)

        # Each branch returns directly, so no path can leave the result unset
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        if 'min-len' in self.checks:
            return '{ .len = ' + self.get_limit_str('min-len') + ', }'
        if 'max-len' in self.checks:
            return 'NLA_POLICY_MAX_LEN(' + self.get_limit_str('max-len') + ')'
        return '{ .type = NLA_BINARY, }'

    def attr_put(self, ri, var):
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->{self.c_name}, {var}->_present.{self.c_name}_len)")

    def _attr_get(self, ri, var):
        len_mem = var + '->_present.' + self.c_name + '_len'
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);"], \
               ['len = ynl_attr_data_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        # Replace any previous value with a private copy of the caller's data
        return [f"free({member});",
                f"{presence}_len = len;",
                f"{member} = malloc({presence}_len);",
                f'memcpy({member}, {self.c_name}, {presence}_len);']
541
542
class TypeBitfield32(Type):
    """``struct nla_bitfield32`` attribute; its policy mask comes from an 'enum'."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        """Return NLA_POLICY_BITFIELD32 with the enum's flag mask.

        Raises:
            Exception: the attribute does not reference an enum.
        """
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        enum = self.family.consts[self.attr['enum']]
        mask = enum.get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        line = f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, sizeof(struct nla_bitfield32))"
        self._attr_put_line(ri, var, line)

    def _attr_get(self, ri, var):
        return f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));", None, None

    def _setter_lines(self, ri, member, presence):
        # The setter argument is a pointer (see Type.arg_member), copy by value
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
566
567
class TypeNest(Type):
    """Attribute holding a nested attribute set, rendered as an embedded struct.

    Recursive nests (outside of an op) are held by pointer instead.
    """

    def is_recursive(self):
        nested_struct = self.family.pure_nested_structs[self.nested_attrs]
        return nested_struct.recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def free(self, ri, var, ref):
        target = f'{var}->{ref}{self.c_name}'
        if self.is_recursive_for_op(ri):
            # Held by pointer and possibly NULL: guard, pass the pointer as-is
            ri.cw.p(f'if ({target})')
            ri.cw.p(f'{self.nested_render_name}_free({target});')
        else:
            ri.cw.p(f'{self.nested_render_name}_free(&{target});')

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        addr_of = '' if self.is_recursive_for_op(ri) else '&'
        put_call = (f"{self.nested_render_name}_put(nlh, "
                    f"{self.enum_name}, {addr_of}{var}->{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [f"if ({self.nested_render_name}_parse(&parg, attr))",
                     "return YNL_PARSE_CB_ERROR;"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        # A nest has no setter of its own: emit setters for each non-recursive
        # member, with this nest appended to the reference path.
        path = (ref or []) + [self.c_name]
        nested_struct = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested_struct.member_list():
            if not member.is_recursive():
                member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
607
608
class TypeMultiAttr(Type):
    """Attribute that may repeat; instances are collected into a C array.

    The element count lives in ``n_<name>``. *base_type* is the Type for a
    single instance and supplies the policy and typol entries.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def _is_nest_sub_type(self):
        # A missing 'type' is treated as a nest. Every method checks this
        # first, so a missing key cannot raise KeyError before the guard.
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if self._is_nest_sub_type():
            return self.nested_struct_type
        elif self.attr['type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        else:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        return self._is_nest_sub_type()

    def free(self, ri, var, ref):
        """Emit code freeing the array (and each nested element, for nests)."""
        if self._is_nest_sub_type():
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        elif self.attr['type'] in scalars:
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        else:
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only count occurrences here; the generated local n_<name> is
        # presumably used later to size and fill the array.
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        """Emit a loop putting one attribute per array element."""
        if self._is_nest_sub_type():
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        elif self.attr['type'] in scalars:
            put_type = self.type
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # For multi-attr we have a count, not presence, hack up the presence
        presence = presence[:-(len('_present.') + len(self.c_name))] + "n_" + self.c_name
        # The caller's array is stored as-is (no copy); the old one is freed
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
670
671
class TypeArrayNest(Type):
    """Indexed array: a nest whose sub-attributes are the array elements.

    _attr_get() only remembers the attribute and counts its entries;
    NOTE(review): element parsing presumably happens elsewhere — confirm.
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'sub-type' not in self.attr:
            return self.nested_struct_type
        sub_type = self.attr['sub-type']
        if sub_type == 'nest':
            return self.nested_struct_type
        if sub_type in scalars:
            return ('__' if ri.ku_space == 'user' else '') + sub_type
        raise Exception(f"Sub-type {sub_type} not supported yet")

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        count_line = f'\t{var}->n_{self.c_name}++;'
        get_lines = [
            f'attr_{self.c_name} = attr;',
            'ynl_attr_for_each_nested(attr2, attr)',
            count_line,
        ]
        return get_lines, None, ['const struct nlattr *attr2;']
697
698
class TypeNestTypeValue(Type):
    """Nest whose 'type-value' levels carry data in their attribute type IDs.

    For each level, the code reads the next nested attribute and stores its
    type ID in a local. Those values are then passed to the nested parser
    after the innermost attribute.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        local_vars = []
        get_lines = []
        innermost = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(levels)};')
            local_vars.append(f'__u32 {", ".join(levels)};')
            for level in levels:
                # Descend one level and record that level's attribute type
                get_lines.append(f'attr_{level} = ynl_attr_data({innermost});')
                get_lines.append(f'{level} = ynl_attr_type(attr_{level});')
                innermost = f'attr_{level}'
            extra_args = f", {', '.join(levels)}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {innermost}{extra_args});")
        return get_lines, init_lines, local_vars
727
728
class Struct:
    """C struct rendered for an attribute set, or for selected members of it.

    With *type_list* only those attributes (in that order) become members;
    without it every attribute of the set is used and the struct is treated
    as a nested one.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        base_name = family.ident_name
        if family.name != c_lower(space_name):
            base_name += '-' + space_name
        self.render_name = c_lower(base_name)
        self.struct_name = 'struct ' + self.render_name
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = f'{self.struct_name} *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False

        member_names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in member_names]
        self.attrs = dict(self.attr_list)

        # Member with the highest value; on ties the later member wins
        max_val = 0
        self.attr_max_val = None
        for _, attr in self.attr_list:
            if attr.value >= max_val:
                max_val = attr.value
                self.attr_max_val = attr

    def __iter__(self):
        for name in self.attrs:
            yield name

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in the order they were collected."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record inherited members (sorted, then c_lower()ed).

        Raises if *new_inherited* differs from what the struct was built with.
        """
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = list(map(c_lower, sorted(self._inherited)))
784
785
class EnumEntry(SpecEnumEntry):
    """Single enum/flags entry, plus the C name it is rendered under."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # True when the value does not follow on from the previous entry
        # (or from 0 for the first entry), or when the set is of type
        # 'flags'; presumably tells the renderer to print the value explicitly.
        expected_value = prev.value + 1 if prev else 0
        self.value_change = (self.value != expected_value
                             or self.enum_set['type'] == 'flags')

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
804
805
class EnumSet(SpecEnumSet):
    """Enum or flags definition together with the C names used to render it."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        # 'enum-name' overrides the C enum name; an empty value leaves
        # enum_name as None and user_type falls back to 'int'.
        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None

        self.user_type = self.enum_name if self.enum_name else 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        # NOTE: keep the attributes above set before the base constructor
        # runs (presumably it creates the entries via new_entry()).
        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return ``(low, high)`` of the entry values.

        Raises:
            ValueError: the enum has no entries (from min()/max()).
            Exception: the distinct values do not form a contiguous run.
        """
        values = [x.value for x in self.entries.values()]
        low = min(values)
        high = max(values)

        # Compare against the distinct values, so duplicate entries neither
        # hide a gap (0, 2, 2) nor reject a contiguous enum (0, 1, 1)
        if high - low + 1 != len(set(values)):
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
841
842
class AttrSet(SpecAttrSet):
    """Attribute set annotated with the C enum naming used by the generator."""
    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # A subset shares the C attribute naming of the set it selects from
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Compute the C identifier for this set.

        A C keyword gets a trailing underscore. A set whose name matches the
        family's C name gets an empty C name.
        """
        c_name = c_lower(self.name)
        if c_name in _C_KW:
            c_name += '_'
        self.c_name = '' if c_name == self.family.c_name else c_name

    def new_attr(self, elem, value):
        """Instantiate the Type subclass that matches elem['type'].

        Elements with a truthy 'multi-attr' are additionally wrapped in
        TypeMultiAttr.

        Raises:
            Exception: for unknown types or unsupported indexed-array sub-types.
        """
        attr_type = elem['type']
        if attr_type in scalars:
            type_cls = TypeScalar
        elif attr_type == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] != 'nest':
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            type_cls = TypeArrayNest
        else:
            by_type = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'binary': TypeBinary,
                'bitfield32': TypeBitfield32,
                'nest': TypeNest,
                'nest-type-value': TypeNestTypeValue,
            }
            if attr_type not in by_type:
                raise Exception(f"No typed class for type {elem['type']}")
            type_cls = by_type[attr_type]

        t = type_cls(self.family, self, elem, value)
        if 'multi-attr' in elem and elem['multi-attr']:
            t = TypeMultiAttr(self.family, self, elem, value, t)
        return t
904
905
class Operation(SpecOperation):
    """Operation spec with the C names the generator needs."""
    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # True only when both 'do' and 'dump' declare a request (checked in
        # that order, stopping at the first mode without one)
        self.dual_policy = all('request' in yaml.get(mode, {}) for mode in ('do', 'dump'))

        # Set later via mark_has_ntf() when another message notifies via this op
        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Resolve the base op, then build its C enum name.

        Async ops use the family's async prefix; all others use the regular
        op prefix.
        """
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that this op is used as a notification."""
        self.has_ntf = True
931
932
class Family(SpecFamily):
    """Family spec augmented with the naming and layout data the C generator needs.

    resolve() does the following:
      * computes the C name and the op-enum prefixes;
      * collects pre/post hooks;
      * splits attribute sets into "root" sets (named by an operation's
        'attribute-set') and "pure nested" sets (reached only through nest
        attributes);
      * marks which attributes are used in requests and/or replies.
    """
    def __init__(self, file_name, exclude_ops):
        # Added by resolve:
        # (declared and immediately deleted so the names are listed here;
        # presumably set later by resolve() or the base class - verify)
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # "linux/foo.h" -> "foo"; anything else falls back to the ident name
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """Compute C naming, hooks and struct layout data for the family.

        Raises:
            Exception: if the protocol is not one of the genetlink variants.
        """
        self.resolve_up(super())

        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
            raise Exception("Codegen only supported for genetlink")

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] -> {'set': unique names, 'list': names in order}
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct (with .request / .reply flags)
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def _mark_notify(self):
        # An op referenced by another message's 'notify' key is marked as
        # having notifications
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Collect, per root attribute set, the attrs used in requests and replies."""
        for op in self.msgs.values():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so nests are defined before their users.

        A struct that references a not-yet-placed, non-recursive nest is moved
        to the end of the (insertion-ordered) dict and retried. The loop is
        bounded at len**2 rounds, so it stops quietly if no order is found.
        """
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _attr, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                    if self.pure_nested_structs[nested].recursive:
                        continue
                    if nested not in pns_key_seen:
                        # Dicts are sorted, this will make struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_sets(self):
        """Discover nested-only attribute sets and how they are used.

        Walks nest attributes breadth-first from the root sets. For every set
        reached only via nesting it creates a Struct and records members that
        the struct inherits from its parent ('type-value' entries, or 'idx'
        for indexed arrays). It then propagates request/reply use and
        detects recursion.

        Raises:
            Exception: if a root attribute set is also used as a nest.
        """
        attr_set_queue = collections.deque(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while attr_set_queue:
            a_set = attr_set_queue.popleft()
            for _attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                if nested in self.root_sets:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                inherit = set()
                if nested not in self.pure_nested_structs:
                    self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)

                # Root sets were rejected above, so members can never be
                # inherited into a set that is also used as root here.
                if 'type-value' in spec:
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'indexed-array':
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _attr, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                    struct.child_nests.add(child_name)
                    child = self.pure_nested_structs.get(child_name)
                    if child:
                        if not child.recursive:
                            struct.child_nests.update(child.child_nests)
                        child.request |= struct.request
                        child.reply |= struct.reply
                if attr_set in struct.child_nests:
                    struct.recursive = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Flag each attribute spec as used in requests and/or replies."""
        for struct in self.pure_nested_structs.values():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_global_policy(self):
        """Build the single policy shared by all ops ('kernel-policy: global').

        The policy lists every request attribute of any 'do'/'dump', in the
        attribute set's declaration order.

        Raises:
            Exception: if ops use different attribute sets, or none uses one.
        """
        global_set = set()
        attr_set_name = None
        for op in self.ops.values():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        if attr_set_name is None:
            raise Exception('For a global policy at least one op must use an attribute set')

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect unique pre/post hook names per op mode, in first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1194
1195
class RenderInfo:
    """Context for rendering one op (or one standalone attribute set).

    Holds the family, the op and its mode ('do', 'dump', 'notify', 'event'),
    the kernel/user space flag, the writer, and Struct objects for the op's
    request and reply attributes. op may be falsy; attr_set must then be
    given.
    """
    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        self.fixed_hdr = None
        if op and op.fixed_header:
            self.fixed_hdr = 'struct ' + c_lower(op.fixed_header)

        # 'do' and 'dump' response parsing is identical. When it is not (no
        # 'do', or differing replies), op_prefix() names the dump reply
        # type separately. Guard on op: it may be falsy here.
        self.type_consistent = True
        if op_mode != 'do' and op and 'dump' in op:
            if 'do' in op:
                if ('reply' in op['do']) != ('reply' in op["dump"]):
                    self.type_consistent = False
                elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False
            else:
                self.type_consistent = False

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            if attr_set in family.consts:
                self.type_name_conflict = True

        self.cw = cw

        self.struct = dict()
        # Notifications reuse the 'do' message layout
        if op_mode == 'notify':
            op_mode = 'do'
        for op_dir in ['request', 'reply']:
            if op:
                type_list = []
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1244
1245
class CodeWriter:
    """Indentation-aware writer for generated C code.

    Output goes to stdout. When out_file is given, output goes to a
    temporary file whose contents close_out_file() copies over out_file.
    """
    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # blank line pending before the next line
        self._block_end = False     # closing '}' deferred by block_end()
        self._silent_block = False  # previous line was a brace-less if/for/while
        self._ind = 0
        self._ifdef_block = None
        # None means "writing to stdout" (or the file was already closed)
        self._out_file = None
        if out_file is None:
            self._out = os.sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """Copy buffered output to out_file; a no-op when writing to stdout.

        With overwrite=False, an existing out_file with identical contents
        is left untouched. In every case the temporary buffer is closed and
        output reverts to stdout, so repeated calls (e.g. from __del__) are
        harmless.
        """
        if getattr(self, '_out_file', None) is None:
            return
        # block_end() defers its '}' until the next line is printed; emit it
        # now so a trailing closing brace is not lost.
        if self._block_end:
            self._block_end = False
            self._out.write('\t' * self._ind + '}\n')
        try:
            self._out.flush()
            # Avoid modifying the file if contents didn't change
            if (not self._overwrite and os.path.isfile(self._out_file) and
                    filecmp.cmp(self._out.name, self._out_file, shallow=False)):
                return
            with open(self._out_file, 'w+') as out_file:
                self._out.seek(0)
                shutil.copyfileobj(self._out, out_file)
        finally:
            self._out.close()
            self._out = os.sys.stdout
            self._out_file = None

    @classmethod
    def _is_cond(cls, line):
        return line.startswith('if') or line.startswith('while') or line.startswith('for')

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        The method applies these rules:
          * A '}' deferred by block_end() is written first, or merged into
            "} else ..." when the line starts with "else".
          * A blank line requested by nl() is written next.
          * Lines ending in ':' (e.g. goto labels) are out-dented by one.
          * The single statement after a brace-less if/for/while gets one
            extra level of indentation.
          * Preprocessor lines ('#...') always go in column 0.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line[-1] == ':':
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if line[0] == '#':
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next printed line."""
        self._nl = True

    def block_start(self, line=''):
        """Write "<line> {" and indent one level."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Dedent and close a block.

        A bare '}' is deferred so that a following "else" can join it. With
        line (e.g. ';'), the brace and line are written immediately.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write doc wrapped into ' *' comment lines shorter than 80 columns.

        With indent, continuation lines get two extra spaces.
        """
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype, wrapping it when it reaches 80 columns.

        When wrapped, a return type longer than 3 characters goes on its own
        line and continuation arguments are aligned under the opening
        parenthesis. doc, if given, is emitted as a block comment first.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        local_vars may be a single string or a list. The caller's list is
        not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a complete function: prototype, '{', locals, body lines, '}'."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        # Local declarations belong inside the braces, not between the
        # prototype and the opening '{'.
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write (name, value) pairs as tab-aligned #defines.

        int values are written raw, str values are quoted, and other values
        are omitted.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write (field, value) pairs as tab-aligned designated initializers."""
        # default=0: an empty member list simply writes nothing
        longest = max((len(x[0]) for x in members), default=0)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch to an '#ifdef CONFIG_<config>' section (None closes any open one)."""
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1441
1442
# Spec attribute types rendered as plain C scalars
scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64', 'uint', 'sint'}

# Suffix appended to generated type names for each message direction
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': ''
}

# Suffix of the user-visible wrapper type for each op mode
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# C keywords; generated identifiers that collide get a trailing '_'
_C_KW = {
    'auto',
    'bool',
    'break',
    'case',
    'char',
    'const',
    'continue',
    'default',
    'do',
    'double',
    'else',
    'enum',
    'extern',
    'float',
    'for',
    'goto',
    'if',
    'inline',
    'int',
    'long',
    'register',
    'restrict',
    'return',
    'short',
    'signed',
    'sizeof',
    'static',
    'struct',
    'switch',
    'typedef',
    'union',
    'unsigned',
    'void',
    'volatile',
    'while'
}
1494
1495
def rdir(direction):
    """Return the opposite message direction ('reply' <-> 'request').

    Any other value (for example '') is returned unchanged.
    """
    opposite = {'reply': 'request', 'request': 'reply'}
    return opposite.get(direction, direction)
1502
1503
def op_prefix(ri, direction, deref=False):
    """Return the C type-name stem for ri's op in the given direction.

    The stem is "<family>_<type_name><tail>". The tail depends on the op
    mode, the direction, whether do/dump replies share a type, and deref.
    When deref is set, the per-message type is named rather than the
    list/wrapper type.
    """
    if not ri.op_mode or ri.op_mode == 'do':
        tail = f"{direction_to_suffix[direction]}"
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp' + ('_dump' if deref else '_list')
    elif deref:
        tail = f"{direction_to_suffix[direction]}"
    else:
        tail = op_mode_to_wrapper[ri.op_mode]

    return f"{ri.family.c_name}_{ri.type_name}{tail}"
1523
1524
def type_name(ri, direction, deref=False):
    """Return the full C struct type ("struct <op_prefix>") for ri/direction."""
    stem = op_prefix(ri, direction, deref=deref)
    return "struct " + stem
1527
1528
def print_prototype(ri, direction, terminate=True, doc=None):
    """Write the prototype of the user-space call for ri's op and mode.

    Dump calls get a "_dump" name suffix. A request argument is added when
    the mode has a request. The return type is a pointer to the reply
    struct when the mode has a reply, otherwise int. With terminate the
    prototype ends in ';'.
    """
    func_name = ri.op.render_name
    if ri.op_mode == 'dump':
        func_name += '_dump'

    mode_spec = ri.op[ri.op_mode]

    params = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        req_type = type_name(ri, direction)
        params.append(f"{req_type} *{direction_to_suffix[direction][1:]}")

    if 'reply' in mode_spec:
        ret_type = f"{type_name(ri, rdir(direction))} *"
    else:
        ret_type = 'int'

    ri.cw.write_func_prot(ret_type, func_name, params, doc=doc,
                          suffix=';' if terminate else '')
1545
1546
def print_req_prototype(ri):
    """Write the request-call prototype for ri's op.

    The op's 'doc' text, when present, is written as a comment above it.
    An op without 'doc' no longer raises KeyError; it just gets no comment.
    """
    doc = ri.op['doc'] if 'doc' in ri.op else None
    print_prototype(ri, "request", doc=doc)
1549
1550
def print_dump_prototype(ri):
    """Write the dump-call prototype for ri's op (no doc comment)."""
    print_prototype(ri, direction="request")
1553
1554
def put_typol_fwd(cw, struct):
    """Write the extern declaration of struct's policy nest."""
    nest_name = f'{struct.render_name}_nest'
    cw.p(f'extern const struct ynl_policy_nest {nest_name};')
1557
1558
def put_typol(cw, struct):
    """Write struct's attribute policy table and the nest that points at it.

    The output is "<render_name>_policy[<MAX> + 1]", filled by each member's
    attr_typol(), followed by "<render_name>_nest" referencing that table.
    """
    prefix = struct.render_name
    max_attr = struct.attr_set.max_name

    cw.block_start(line=f'const struct ynl_policy_attr {prefix}_policy[{max_attr} + 1] =')
    for _member_name, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {prefix}_nest =')
    cw.p(f'.max_attr = {max_attr},')
    cw.p(f'.table = {prefix}_policy,')
    cw.block_end(line=';')
    cw.nl()
1574
1575
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Write "<render_name>_str()", a bounds-checked lookup into map_name.

    The argument is typed as the enum's user type when enum is given,
    otherwise as int. For flags enums the value is first converted to a bit
    index with ffs() - 1. Out-of-range values return NULL.
    """
    arg_type = enum.user_type if enum else 'int'
    cw.write_func_prot('const char *', f'{render_name}_str', [f'{arg_type} {arg_name}'])

    cw.block_start()
    is_flags = enum and enum.type == 'flags'
    if is_flags:
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))')
    cw.p('return NULL;')
    cw.p(f'return {map_name}[{arg_name}];')
    cw.block_end()
    cw.nl()
1589
1590
def put_op_name_fwd(family, cw):
    """Declare the family's "<c_name>_op_str(int op)" lookup function."""
    func_name = f'{family.c_name}_op_str'
    cw.write_func_prot('const char *', func_name, ['int op'], suffix=';')
1593
1594
def put_op_name(family, cw):
    """Write the op-value -> op-name string map and its lookup function.

    Messages without a reply value get no entry. When several messages
    share a reply value (legacy families), only the message that
    family.rsp_by_value maps the value to gets an entry. The others are
    written as skip comments.
    """
    strmap = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {strmap}[] =")
    for msg_name, msg in family.msgs.items():
        if not msg.rsp_value:
            continue
        if family.rsp_by_value[msg.rsp_value] != msg:
            cw.p(f'// skip "{msg_name}", duplicate reply value')
            continue

        # Prefer the symbolic enum name when request and reply values agree
        index = msg.enum_name if msg.req_value == msg.rsp_value else msg.rsp_value
        cw.p(f'[{index}] = "{msg_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', strmap, 'op')
1614
1615
def put_enum_to_str_fwd(family, cw, enum):
    """Declare "<render_name>_str(<user_type> value)" for enum."""
    param = f'{enum.user_type} value'
    cw.write_func_prot('const char *', f'{enum.render_name}_str', [param], suffix=';')
1619
1620
def put_enum_to_str(family, cw, enum):
    """Write enum's value -> name string map and its lookup function."""
    strmap = f'{enum.render_name}_strmap'

    cw.block_start(line=f"static const char * const {strmap}[] =")
    for item in enum.entries.values():
        cw.p(f'[{item.value}] = "{item.name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, strmap, 'value', enum=enum)
1630
1631
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Write the prototype of "<render_name>_put()", the nest serializer.

    Use suffix=';' for a declaration and suffix='' when a body follows.
    """
    params = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]
    ri.cw.write_func_prot('int', f'{struct.render_name}_put', params, suffix=suffix)
1639
1640
def put_req_nested(ri, struct):
    """Write the definition of "<render_name>_put()".

    The generated C opens a nest of type attr_type, lets every member emit
    its own put code for obj, closes the nest and returns 0.
    """
    cw = ri.cw
    put_req_nested_prototype(ri, struct, suffix='')
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = ynl_attr_nest_start(nlh, attr_type);")
    for _member_name, member in struct.member_list():
        member.attr_put(ri, "obj")
    cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1657
1658
1659def _multi_parse(ri, struct, init_lines, local_vars):
1660    if struct.nested:
1661        iter_line = "ynl_attr_for_each_nested(attr, nested)"
1662    else:
1663        if ri.fixed_hdr:
1664            local_vars += ['void *hdr;']
1665        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
1666
1667    array_nests = set()
1668    multi_attrs = set()
1669    needs_parg = False
1670    for arg, aspec in struct.member_list():
1671        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
1672            if aspec["sub-type"] == 'nest':
1673                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
1674                array_nests.add(arg)
1675            else:
1676                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
1677        if 'multi-attr' in aspec:
1678            multi_attrs.add(arg)
1679        needs_parg |= 'nested-attributes' in aspec
1680    if array_nests or multi_attrs:
1681        local_vars.append('int i;')
1682    if needs_parg:
1683        local_vars.append('struct ynl_parse_arg parg;')
1684        init_lines.append('parg.ys = yarg->ys;')
1685
1686    all_multi = array_nests | multi_attrs
1687
1688    for anest in sorted(all_multi):
1689        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")
1690
1691    ri.cw.block_start()
1692    ri.cw.write_func_lvar(local_vars)
1693
1694    for line in init_lines:
1695        ri.cw.p(line)
1696    ri.cw.nl()
1697
1698    for arg in struct.inherited:
1699        ri.cw.p(f'dst->{arg} = {arg};')
1700
1701    if ri.fixed_hdr:
1702        ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
1703        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
1704    for anest in sorted(all_multi):
1705        aspec = struct[anest]
1706        ri.cw.p(f"if (dst->{aspec.c_name})")
1707        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')
1708
1709    ri.cw.nl()
1710    ri.cw.block_start(line=iter_line)
1711    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
1712    ri.cw.nl()
1713
1714    first = True
1715    for _, arg in struct.member_list():
1716        good = arg.attr_get(ri, 'dst', first=first)
1717        # First may be 'unused' or 'pad', ignore those
1718        first &= not good
1719
1720    ri.cw.block_end()
1721    ri.cw.nl()
1722
1723    for anest in sorted(array_nests):
1724        aspec = struct[anest]
1725
1726        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
1727        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1728        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
1729        ri.cw.p('i = 0;')
1730        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1731        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
1732        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1733        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
1734        ri.cw.p('return YNL_PARSE_CB_ERROR;')
1735        ri.cw.p('i++;')
1736        ri.cw.block_end()
1737        ri.cw.block_end()
1738    ri.cw.nl()
1739
1740    for anest in sorted(multi_attrs):
1741        aspec = struct[anest]
1742        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
1743        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1744        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
1745        ri.cw.p('i = 0;')
1746        if 'nested-attributes' in aspec:
1747            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1748        ri.cw.block_start(line=iter_line)
1749        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
1750        if 'nested-attributes' in aspec:
1751            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1752            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
1753            ri.cw.p('return YNL_PARSE_CB_ERROR;')
1754        elif aspec.type in scalars:
1755            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
1756        else:
1757            raise Exception('Nest parsing type not supported yet')
1758        ri.cw.p('i++;')
1759        ri.cw.block_end()
1760        ri.cw.block_end()
1761        ri.cw.block_end()
1762    ri.cw.nl()
1763
1764    if struct.nested:
1765        ri.cw.p('return 0;')
1766    else:
1767        ri.cw.p('return YNL_PARSE_CB_OK;')
1768    ri.cw.block_end()
1769    ri.cw.nl()
1770
1771
1772def parse_rsp_nested_prototype(ri, struct, suffix=';'):
1773    func_args = ['struct ynl_parse_arg *yarg',
1774                 'const struct nlattr *nested']
1775    for arg in struct.inherited:
1776        func_args.append('__u32 ' + arg)
1777
1778    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args,
1779                          suffix=suffix)
1780
1781
1782def parse_rsp_nested(ri, struct):
1783    parse_rsp_nested_prototype(ri, struct, suffix='')
1784
1785    local_vars = ['const struct nlattr *attr;',
1786                  f'{struct.ptr_name}dst = yarg->data;']
1787    init_lines = []
1788
1789    if struct.member_list():
1790        _multi_parse(ri, struct, init_lines, local_vars)
1791    else:
1792        # Empty nest
1793        ri.cw.block_start()
1794        ri.cw.p('return 0;')
1795        ri.cw.block_end()
1796        ri.cw.nl()
1797
1798
1799def parse_rsp_msg(ri, deref=False):
1800    if 'reply' not in ri.op[ri.op_mode] and ri.op_mode != 'event':
1801        return
1802
1803    func_args = ['const struct nlmsghdr *nlh',
1804                 'struct ynl_parse_arg *yarg']
1805
1806    local_vars = [f'{type_name(ri, "reply", deref=deref)} *dst;',
1807                  'const struct nlattr *attr;']
1808    init_lines = ['dst = yarg->data;']
1809
1810    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse', func_args)
1811
1812    if ri.struct["reply"].member_list():
1813        _multi_parse(ri, ri.struct["reply"], init_lines, local_vars)
1814    else:
1815        # Empty reply
1816        ri.cw.block_start()
1817        ri.cw.p('return YNL_PARSE_CB_OK;')
1818        ri.cw.block_end()
1819        ri.cw.nl()
1820
1821
1822def print_req(ri):
1823    ret_ok = '0'
1824    ret_err = '-1'
1825    direction = "request"
1826    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
1827                  'struct nlmsghdr *nlh;',
1828                  'int err;']
1829
1830    if 'reply' in ri.op[ri.op_mode]:
1831        ret_ok = 'rsp'
1832        ret_err = 'NULL'
1833        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;']
1834
1835    if ri.fixed_hdr:
1836        local_vars += ['size_t hdr_len;',
1837                       'void *hdr;']
1838
1839    print_prototype(ri, direction, terminate=False)
1840    ri.cw.block_start()
1841    ri.cw.write_func_lvar(local_vars)
1842
1843    ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")
1844
1845    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
1846    if 'reply' in ri.op[ri.op_mode]:
1847        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
1848    ri.cw.nl()
1849
1850    if ri.fixed_hdr:
1851        ri.cw.p("hdr_len = sizeof(req->_hdr);")
1852        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
1853        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
1854        ri.cw.nl()
1855
1856    for _, attr in ri.struct["request"].member_list():
1857        attr.attr_put(ri, "req")
1858    ri.cw.nl()
1859
1860    if 'reply' in ri.op[ri.op_mode]:
1861        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
1862        ri.cw.p('yrs.yarg.data = rsp;')
1863        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
1864        if ri.op.value is not None:
1865            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
1866        else:
1867            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
1868        ri.cw.nl()
1869    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
1870    ri.cw.p('if (err < 0)')
1871    if 'reply' in ri.op[ri.op_mode]:
1872        ri.cw.p('goto err_free;')
1873    else:
1874        ri.cw.p('return -1;')
1875    ri.cw.nl()
1876
1877    ri.cw.p(f"return {ret_ok};")
1878    ri.cw.nl()
1879
1880    if 'reply' in ri.op[ri.op_mode]:
1881        ri.cw.p('err_free:')
1882        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
1883        ri.cw.p(f"return {ret_err};")
1884
1885    ri.cw.block_end()
1886
1887
1888def print_dump(ri):
1889    direction = "request"
1890    print_prototype(ri, direction, terminate=False)
1891    ri.cw.block_start()
1892    local_vars = ['struct ynl_dump_state yds = {};',
1893                  'struct nlmsghdr *nlh;',
1894                  'int err;']
1895
1896    if ri.fixed_hdr:
1897        local_vars += ['size_t hdr_len;',
1898                       'void *hdr;']
1899
1900    ri.cw.write_func_lvar(local_vars)
1901
1902    ri.cw.p('yds.yarg.ys = ys;')
1903    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
1904    ri.cw.p("yds.yarg.data = NULL;")
1905    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
1906    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
1907    if ri.op.value is not None:
1908        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
1909    else:
1910        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
1911    ri.cw.nl()
1912    ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")
1913
1914    if ri.fixed_hdr:
1915        ri.cw.p("hdr_len = sizeof(req->_hdr);")
1916        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
1917        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
1918        ri.cw.nl()
1919
1920    if "request" in ri.op[ri.op_mode]:
1921        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
1922        ri.cw.nl()
1923        for _, attr in ri.struct["request"].member_list():
1924            attr.attr_put(ri, "req")
1925    ri.cw.nl()
1926
1927    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
1928    ri.cw.p('if (err < 0)')
1929    ri.cw.p('goto free_list;')
1930    ri.cw.nl()
1931
1932    ri.cw.p('return yds.first;')
1933    ri.cw.nl()
1934    ri.cw.p('free_list:')
1935    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
1936    ri.cw.p('return NULL;')
1937    ri.cw.block_end()
1938
1939
1940def call_free(ri, direction, var):
1941    return f"{op_prefix(ri, direction)}_free({var});"
1942
1943
1944def free_arg_name(direction):
1945    if direction:
1946        return direction_to_suffix[direction][1:]
1947    return 'obj'
1948
1949
1950def print_alloc_wrapper(ri, direction):
1951    name = op_prefix(ri, direction)
1952    ri.cw.write_func_prot(f'static inline struct {name} *', f"{name}_alloc", [f"void"])
1953    ri.cw.block_start()
1954    ri.cw.p(f'return calloc(1, sizeof(struct {name}));')
1955    ri.cw.block_end()
1956
1957
1958def print_free_prototype(ri, direction, suffix=';'):
1959    name = op_prefix(ri, direction)
1960    struct_name = name
1961    if ri.type_name_conflict:
1962        struct_name += '_'
1963    arg = free_arg_name(direction)
1964    ri.cw.write_func_prot('void', f"{name}_free", [f"struct {struct_name} *{arg}"], suffix=suffix)
1965
1966
1967def _print_type(ri, direction, struct):
1968    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
1969    if not direction and ri.type_name_conflict:
1970        suffix += '_'
1971
1972    if ri.op_mode == 'dump':
1973        suffix += '_dump'
1974
1975    ri.cw.block_start(line=f"struct {ri.family.c_name}{suffix}")
1976
1977    if ri.fixed_hdr:
1978        ri.cw.p(ri.fixed_hdr + ' _hdr;')
1979        ri.cw.nl()
1980
1981    meta_started = False
1982    for _, attr in struct.member_list():
1983        for type_filter in ['len', 'bit']:
1984            line = attr.presence_member(ri.ku_space, type_filter)
1985            if line:
1986                if not meta_started:
1987                    ri.cw.block_start(line=f"struct")
1988                    meta_started = True
1989                ri.cw.p(line)
1990    if meta_started:
1991        ri.cw.block_end(line='_present;')
1992        ri.cw.nl()
1993
1994    for arg in struct.inherited:
1995        ri.cw.p(f"__u32 {arg};")
1996
1997    for _, attr in struct.member_list():
1998        attr.struct_member(ri)
1999
2000    ri.cw.block_end(line=';')
2001    ri.cw.nl()
2002
2003
2004def print_type(ri, direction):
2005    _print_type(ri, direction, ri.struct[direction])
2006
2007
2008def print_type_full(ri, struct):
2009    _print_type(ri, "", struct)
2010
2011
2012def print_type_helpers(ri, direction, deref=False):
2013    print_free_prototype(ri, direction)
2014    ri.cw.nl()
2015
2016    if ri.ku_space == 'user' and direction == 'request':
2017        for _, attr in ri.struct[direction].member_list():
2018            attr.setter(ri, ri.attr_set, direction, deref=deref)
2019    ri.cw.nl()
2020
2021
2022def print_req_type_helpers(ri):
2023    if len(ri.struct["request"].attr_list) == 0:
2024        return
2025    print_alloc_wrapper(ri, "request")
2026    print_type_helpers(ri, "request")
2027
2028
2029def print_rsp_type_helpers(ri):
2030    if 'reply' not in ri.op[ri.op_mode]:
2031        return
2032    print_type_helpers(ri, "reply")
2033
2034
2035def print_parse_prototype(ri, direction, terminate=True):
2036    suffix = "_rsp" if direction == "reply" else "_req"
2037    term = ';' if terminate else ''
2038
2039    ri.cw.write_func_prot('void', f"{ri.op.render_name}{suffix}_parse",
2040                          ['const struct nlattr **tb',
2041                           f"struct {ri.op.render_name}{suffix} *req"],
2042                          suffix=term)
2043
2044
2045def print_req_type(ri):
2046    if len(ri.struct["request"].attr_list) == 0:
2047        return
2048    print_type(ri, "request")
2049
2050
2051def print_req_free(ri):
2052    if 'request' not in ri.op[ri.op_mode]:
2053        return
2054    _free_type(ri, 'request', ri.struct['request'])
2055
2056
2057def print_rsp_type(ri):
2058    if (ri.op_mode == 'do' or ri.op_mode == 'dump') and 'reply' in ri.op[ri.op_mode]:
2059        direction = 'reply'
2060    elif ri.op_mode == 'event':
2061        direction = 'reply'
2062    else:
2063        return
2064    print_type(ri, direction)
2065
2066
2067def print_wrapped_type(ri):
2068    ri.cw.block_start(line=f"{type_name(ri, 'reply')}")
2069    if ri.op_mode == 'dump':
2070        ri.cw.p(f"{type_name(ri, 'reply')} *next;")
2071    elif ri.op_mode == 'notify' or ri.op_mode == 'event':
2072        ri.cw.p('__u16 family;')
2073        ri.cw.p('__u8 cmd;')
2074        ri.cw.p('struct ynl_ntf_base_type *next;')
2075        ri.cw.p(f"void (*free)({type_name(ri, 'reply')} *ntf);")
2076    ri.cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
2077    ri.cw.block_end(line=';')
2078    ri.cw.nl()
2079    print_free_prototype(ri, 'reply')
2080    ri.cw.nl()
2081
2082
2083def _free_type_members_iter(ri, struct):
2084    for _, attr in struct.member_list():
2085        if attr.free_needs_iter():
2086            ri.cw.p('unsigned int i;')
2087            ri.cw.nl()
2088            break
2089
2090
2091def _free_type_members(ri, var, struct, ref=''):
2092    for _, attr in struct.member_list():
2093        attr.free(ri, var, ref)
2094
2095
2096def _free_type(ri, direction, struct):
2097    var = free_arg_name(direction)
2098
2099    print_free_prototype(ri, direction, suffix='')
2100    ri.cw.block_start()
2101    _free_type_members_iter(ri, struct)
2102    _free_type_members(ri, var, struct)
2103    if direction:
2104        ri.cw.p(f'free({var});')
2105    ri.cw.block_end()
2106    ri.cw.nl()
2107
2108
2109def free_rsp_nested_prototype(ri):
2110        print_free_prototype(ri, "")
2111
2112
2113def free_rsp_nested(ri, struct):
2114    _free_type(ri, "", struct)
2115
2116
2117def print_rsp_free(ri):
2118    if 'reply' not in ri.op[ri.op_mode]:
2119        return
2120    _free_type(ri, 'reply', ri.struct['reply'])
2121
2122
2123def print_dump_type_free(ri):
2124    sub_type = type_name(ri, 'reply')
2125
2126    print_free_prototype(ri, 'reply', suffix='')
2127    ri.cw.block_start()
2128    ri.cw.p(f"{sub_type} *next = rsp;")
2129    ri.cw.nl()
2130    ri.cw.block_start(line='while ((void *)next != YNL_LIST_END)')
2131    _free_type_members_iter(ri, ri.struct['reply'])
2132    ri.cw.p('rsp = next;')
2133    ri.cw.p('next = rsp->next;')
2134    ri.cw.nl()
2135
2136    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
2137    ri.cw.p(f'free(rsp);')
2138    ri.cw.block_end()
2139    ri.cw.block_end()
2140    ri.cw.nl()
2141
2142
2143def print_ntf_type_free(ri):
2144    print_free_prototype(ri, 'reply', suffix='')
2145    ri.cw.block_start()
2146    _free_type_members_iter(ri, ri.struct['reply'])
2147    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
2148    ri.cw.p(f'free(rsp);')
2149    ri.cw.block_end()
2150    ri.cw.nl()
2151
2152
2153def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
2154    if terminate and ri and policy_should_be_static(struct.family):
2155        return
2156
2157    if terminate:
2158        prefix = 'extern '
2159    else:
2160        if ri and policy_should_be_static(struct.family):
2161            prefix = 'static '
2162        else:
2163            prefix = ''
2164
2165    suffix = ';' if terminate else ' = {'
2166
2167    max_attr = struct.attr_max_val
2168    if ri:
2169        name = ri.op.render_name
2170        if ri.op.dual_policy:
2171            name += '_' + ri.op_mode
2172    else:
2173        name = struct.render_name
2174    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2175
2176
2177def print_req_policy(cw, struct, ri=None):
2178    if ri and ri.op:
2179        cw.ifdef_block(ri.op.get('config-cond', None))
2180    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
2181    for _, arg in struct.member_list():
2182        arg.attr_policy(cw)
2183    cw.p("};")
2184    cw.ifdef_block(None)
2185    cw.nl()
2186
2187
2188def kernel_can_gen_family_struct(family):
2189    return family.proto == 'genetlink'
2190
2191
2192def policy_should_be_static(family):
2193    return family.kernel_policy == 'split' or kernel_can_gen_family_struct(family)
2194
2195
2196def print_kernel_policy_ranges(family, cw):
2197    first = True
2198    for _, attr_set in family.attr_sets.items():
2199        if attr_set.subset_of:
2200            continue
2201
2202        for _, attr in attr_set.items():
2203            if not attr.request:
2204                continue
2205            if 'full-range' not in attr.checks:
2206                continue
2207
2208            if first:
2209                cw.p('/* Integer value ranges */')
2210                first = False
2211
2212            sign = '' if attr.type[0] == 'u' else '_signed'
2213            suffix = 'ULL' if attr.type[0] == 'u' else 'LL'
2214            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
2215            members = []
2216            if 'min' in attr.checks:
2217                members.append(('min', attr.get_limit_str('min', suffix=suffix)))
2218            if 'max' in attr.checks:
2219                members.append(('max', attr.get_limit_str('max', suffix=suffix)))
2220            cw.write_struct_init(members)
2221            cw.block_end(line=';')
2222            cw.nl()
2223
2224
2225def print_kernel_op_table_fwd(family, cw, terminate):
2226    exported = not kernel_can_gen_family_struct(family)
2227
2228    if not terminate or exported:
2229        cw.p(f"/* Ops table for {family.ident_name} */")
2230
2231        pol_to_struct = {'global': 'genl_small_ops',
2232                         'per-op': 'genl_ops',
2233                         'split': 'genl_split_ops'}
2234        struct_type = pol_to_struct[family.kernel_policy]
2235
2236        if not exported:
2237            cnt = ""
2238        elif family.kernel_policy == 'split':
2239            cnt = 0
2240            for op in family.ops.values():
2241                if 'do' in op:
2242                    cnt += 1
2243                if 'dump' in op:
2244                    cnt += 1
2245        else:
2246            cnt = len(family.ops)
2247
2248        qual = 'static const' if not exported else 'const'
2249        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
2250        if terminate:
2251            cw.p(f"extern {line};")
2252        else:
2253            cw.block_start(line=line + ' =')
2254
2255    if not terminate:
2256        return
2257
2258    cw.nl()
2259    for name in family.hooks['pre']['do']['list']:
2260        cw.write_func_prot('int', c_lower(name),
2261                           ['const struct genl_split_ops *ops',
2262                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2263    for name in family.hooks['post']['do']['list']:
2264        cw.write_func_prot('void', c_lower(name),
2265                           ['const struct genl_split_ops *ops',
2266                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2267    for name in family.hooks['pre']['dump']['list']:
2268        cw.write_func_prot('int', c_lower(name),
2269                           ['struct netlink_callback *cb'], suffix=';')
2270    for name in family.hooks['post']['dump']['list']:
2271        cw.write_func_prot('int', c_lower(name),
2272                           ['struct netlink_callback *cb'], suffix=';')
2273
2274    cw.nl()
2275
2276    for op_name, op in family.ops.items():
2277        if op.is_async:
2278            continue
2279
2280        if 'do' in op:
2281            name = c_lower(f"{family.ident_name}-nl-{op_name}-doit")
2282            cw.write_func_prot('int', name,
2283                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2284
2285        if 'dump' in op:
2286            name = c_lower(f"{family.ident_name}-nl-{op_name}-dumpit")
2287            cw.write_func_prot('int', name,
2288                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
2289    cw.nl()
2290
2291
2292def print_kernel_op_table_hdr(family, cw):
2293    print_kernel_op_table_fwd(family, cw, terminate=True)
2294
2295
2296def print_kernel_op_table(family, cw):
2297    print_kernel_op_table_fwd(family, cw, terminate=False)
2298    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
2299        for op_name, op in family.ops.items():
2300            if op.is_async:
2301                continue
2302
2303            cw.ifdef_block(op.get('config-cond', None))
2304            cw.block_start()
2305            members = [('cmd', op.enum_name)]
2306            if 'dont-validate' in op:
2307                members.append(('validate',
2308                                ' | '.join([c_upper('genl-dont-validate-' + x)
2309                                            for x in op['dont-validate']])), )
2310            for op_mode in ['do', 'dump']:
2311                if op_mode in op:
2312                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
2313                    members.append((op_mode + 'it', name))
2314            if family.kernel_policy == 'per-op':
2315                struct = Struct(family, op['attribute-set'],
2316                                type_list=op['do']['request']['attributes'])
2317
2318                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
2319                members.append(('policy', name))
2320                members.append(('maxattr', struct.attr_max_val.enum_name))
2321            if 'flags' in op:
2322                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
2323            cw.write_struct_init(members)
2324            cw.block_end(line=',')
2325    elif family.kernel_policy == 'split':
2326        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
2327                    'dump': {'pre': 'start', 'post': 'done'}}
2328
2329        for op_name, op in family.ops.items():
2330            for op_mode in ['do', 'dump']:
2331                if op.is_async or op_mode not in op:
2332                    continue
2333
2334                cw.ifdef_block(op.get('config-cond', None))
2335                cw.block_start()
2336                members = [('cmd', op.enum_name)]
2337                if 'dont-validate' in op:
2338                    dont_validate = []
2339                    for x in op['dont-validate']:
2340                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
2341                            continue
2342                        if op_mode == "dump" and x == 'strict':
2343                            continue
2344                        dont_validate.append(x)
2345
2346                    if dont_validate:
2347                        members.append(('validate',
2348                                        ' | '.join([c_upper('genl-dont-validate-' + x)
2349                                                    for x in dont_validate])), )
2350                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
2351                if 'pre' in op[op_mode]:
2352                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
2353                members.append((op_mode + 'it', name))
2354                if 'post' in op[op_mode]:
2355                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
2356                if 'request' in op[op_mode]:
2357                    struct = Struct(family, op['attribute-set'],
2358                                    type_list=op[op_mode]['request']['attributes'])
2359
2360                    if op.dual_policy:
2361                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
2362                    else:
2363                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
2364                    members.append(('policy', name))
2365                    members.append(('maxattr', struct.attr_max_val.enum_name))
2366                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
2367                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
2368                cw.write_struct_init(members)
2369                cw.block_end(line=',')
2370    cw.ifdef_block(None)
2371
2372    cw.block_end(line=';')
2373    cw.nl()
2374
2375
2376def print_kernel_mcgrp_hdr(family, cw):
2377    if not family.mcgrps['list']:
2378        return
2379
2380    cw.block_start('enum')
2381    for grp in family.mcgrps['list']:
2382        grp_id = c_upper(f"{family.ident_name}-nlgrp-{grp['name']},")
2383        cw.p(grp_id)
2384    cw.block_end(';')
2385    cw.nl()
2386
2387
2388def print_kernel_mcgrp_src(family, cw):
2389    if not family.mcgrps['list']:
2390        return
2391
2392    cw.block_start('static const struct genl_multicast_group ' + family.c_name + '_nl_mcgrps[] =')
2393    for grp in family.mcgrps['list']:
2394        name = grp['name']
2395        grp_id = c_upper(f"{family.ident_name}-nlgrp-{name}")
2396        cw.p('[' + grp_id + '] = { "' + name + '", },')
2397    cw.block_end(';')
2398    cw.nl()
2399
2400
2401def print_kernel_family_struct_hdr(family, cw):
2402    if not kernel_can_gen_family_struct(family):
2403        return
2404
2405    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
2406    cw.nl()
2407    if 'sock-priv' in family.kernel_family:
2408        cw.p(f'void {family.c_name}_nl_sock_priv_init({family.kernel_family["sock-priv"]} *priv);')
2409        cw.p(f'void {family.c_name}_nl_sock_priv_destroy({family.kernel_family["sock-priv"]} *priv);')
2410        cw.nl()
2411
2412
2413def print_kernel_family_struct_src(family, cw):
2414    if not kernel_can_gen_family_struct(family):
2415        return
2416
2417    if 'sock-priv' in family.kernel_family:
2418        # Generate "trampolines" to make CFI happy
2419        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
2420                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
2421                      ["void *priv"])
2422        cw.nl()
2423        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
2424                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
2425                      ["void *priv"])
2426        cw.nl()
2427
2428    cw.block_start(f"struct genl_family {family.ident_name}_nl_family __ro_after_init =")
2429    cw.p('.name\t\t= ' + family.fam_key + ',')
2430    cw.p('.version\t= ' + family.ver_key + ',')
2431    cw.p('.netnsok\t= true,')
2432    cw.p('.parallel_ops\t= true,')
2433    cw.p('.module\t\t= THIS_MODULE,')
2434    if family.kernel_policy == 'per-op':
2435        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
2436        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
2437    elif family.kernel_policy == 'split':
2438        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
2439        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
2440    if family.mcgrps['list']:
2441        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
2442        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
2443    if 'sock-priv' in family.kernel_family:
2444        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
2445        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
2446        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
2447    cw.block_end(';')
2448
2449
2450def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
2451    start_line = 'enum'
2452    if enum_name in obj:
2453        if obj[enum_name]:
2454            start_line = 'enum ' + c_lower(obj[enum_name])
2455    elif ckey and ckey in obj:
2456        start_line = 'enum ' + family.c_name + '_' + c_lower(obj[ckey])
2457    cw.block_start(line=start_line)
2458
2459
2460def render_uapi_unified(family, cw, max_by_define, separate_ntf):
2461    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
2462    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
2463    max_value = f"({cnt_name} - 1)"
2464
2465    uapi_enum_start(family, cw, family['operations'], 'enum-name')
2466    val = 0
2467    for op in family.msgs.values():
2468        if separate_ntf and ('notify' in op or 'event' in op):
2469            continue
2470
2471        suffix = ','
2472        if op.value != val:
2473            suffix = f" = {op.value},"
2474            val = op.value
2475        cw.p(op.enum_name + suffix)
2476        val += 1
2477    cw.nl()
2478    cw.p(cnt_name + ('' if max_by_define else ','))
2479    if not max_by_define:
2480        cw.p(f"{max_name} = {max_value}")
2481    cw.block_end(line=';')
2482    if max_by_define:
2483        cw.p(f"#define {max_name} {max_value}")
2484    cw.nl()
2485
2486
2487def render_uapi_directional(family, cw, max_by_define):
2488    max_name = f"{family.op_prefix}USER_MAX"
2489    cnt_name = f"__{family.op_prefix}USER_CNT"
2490    max_value = f"({cnt_name} - 1)"
2491
2492    cw.block_start(line='enum')
2493    cw.p(c_upper(f'{family.name}_MSG_USER_NONE = 0,'))
2494    val = 0
2495    for op in family.msgs.values():
2496        if 'do' in op and 'event' not in op:
2497            suffix = ','
2498            if op.value and op.value != val:
2499                suffix = f" = {op.value},"
2500                val = op.value
2501            cw.p(op.enum_name + suffix)
2502            val += 1
2503    cw.nl()
2504    cw.p(cnt_name + ('' if max_by_define else ','))
2505    if not max_by_define:
2506        cw.p(f"{max_name} = {max_value}")
2507    cw.block_end(line=';')
2508    if max_by_define:
2509        cw.p(f"#define {max_name} {max_value}")
2510    cw.nl()
2511
2512    max_name = f"{family.op_prefix}KERNEL_MAX"
2513    cnt_name = f"__{family.op_prefix}KERNEL_CNT"
2514    max_value = f"({cnt_name} - 1)"
2515
2516    cw.block_start(line='enum')
2517    cw.p(c_upper(f'{family.name}_MSG_KERNEL_NONE = 0,'))
2518    val = 0
2519    for op in family.msgs.values():
2520        if ('do' in op and 'reply' in op['do']) or 'notify' in op or 'event' in op:
2521            enum_name = op.enum_name
2522            if 'event' not in op and 'notify' not in op:
2523                enum_name = f'{enum_name}_REPLY'
2524
2525            suffix = ','
2526            if op.value and op.value != val:
2527                suffix = f" = {op.value},"
2528                val = op.value
2529            cw.p(enum_name + suffix)
2530            val += 1
2531    cw.nl()
2532    cw.p(cnt_name + ('' if max_by_define else ','))
2533    if not max_by_define:
2534        cw.p(f"{max_name} = {max_value}")
2535    cw.block_end(line=';')
2536    if max_by_define:
2537        cw.p(f"#define {max_name} {max_value}")
2538    cw.nl()
2539
2540
2541def render_uapi(family, cw):
2542    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
2543    hdr_prot = hdr_prot.replace('/', '_')
2544    cw.p('#ifndef ' + hdr_prot)
2545    cw.p('#define ' + hdr_prot)
2546    cw.nl()
2547
2548    defines = [(family.fam_key, family["name"]),
2549               (family.ver_key, family.get('version', 1))]
2550    cw.writes_defines(defines)
2551    cw.nl()
2552
2553    defines = []
2554    for const in family['definitions']:
2555        if const.get('header'):
2556            continue
2557
2558        if const['type'] != 'const':
2559            cw.writes_defines(defines)
2560            defines = []
2561            cw.nl()
2562
2563        # Write kdoc for enum and flags (one day maybe also structs)
2564        if const['type'] == 'enum' or const['type'] == 'flags':
2565            enum = family.consts[const['name']]
2566
2567            if enum.header:
2568                continue
2569
2570            if enum.has_doc():
2571                if enum.has_entry_doc():
2572                    cw.p('/**')
2573                    doc = ''
2574                    if 'doc' in enum:
2575                        doc = ' - ' + enum['doc']
2576                    cw.write_doc_line(enum.enum_name + doc)
2577                else:
2578                    cw.p('/*')
2579                    cw.write_doc_line(enum['doc'], indent=False)
2580                for entry in enum.entries.values():
2581                    if entry.has_doc():
2582                        doc = '@' + entry.c_name + ': ' + entry['doc']
2583                        cw.write_doc_line(doc)
2584                cw.p(' */')
2585
2586            uapi_enum_start(family, cw, const, 'name')
2587            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
2588            for entry in enum.entries.values():
2589                suffix = ','
2590                if entry.value_change:
2591                    suffix = f" = {entry.user_value()}" + suffix
2592                cw.p(entry.c_name + suffix)
2593
2594            if const.get('render-max', False):
2595                cw.nl()
2596                cw.p('/* private: */')
2597                if const['type'] == 'flags':
2598                    max_name = c_upper(name_pfx + 'mask')
2599                    max_val = f' = {enum.get_mask()},'
2600                    cw.p(max_name + max_val)
2601                else:
2602                    cnt_name = enum.enum_cnt_name
2603                    max_name = c_upper(name_pfx + 'max')
2604                    if not cnt_name:
2605                        cnt_name = '__' + name_pfx + 'max'
2606                    cw.p(c_upper(cnt_name) + ',')
2607                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
2608            cw.block_end(line=';')
2609            cw.nl()
2610        elif const['type'] == 'const':
2611            defines.append([c_upper(family.get('c-define-name',
2612                                               f"{family.ident_name}-{const['name']}")),
2613                            const['value']])
2614
2615    if defines:
2616        cw.writes_defines(defines)
2617        cw.nl()
2618
2619    max_by_define = family.get('max-by-define', False)
2620
2621    for _, attr_set in family.attr_sets.items():
2622        if attr_set.subset_of:
2623            continue
2624
2625        max_value = f"({attr_set.cnt_name} - 1)"
2626
2627        val = 0
2628        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
2629        for _, attr in attr_set.items():
2630            suffix = ','
2631            if attr.value != val:
2632                suffix = f" = {attr.value},"
2633                val = attr.value
2634            val += 1
2635            cw.p(attr.enum_name + suffix)
2636        if attr_set.items():
2637            cw.nl()
2638        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
2639        if not max_by_define:
2640            cw.p(f"{attr_set.max_name} = {max_value}")
2641        cw.block_end(line=';')
2642        if max_by_define:
2643            cw.p(f"#define {attr_set.max_name} {max_value}")
2644        cw.nl()
2645
2646    # Commands
2647    separate_ntf = 'async-prefix' in family['operations']
2648
2649    if family.msg_id_model == 'unified':
2650        render_uapi_unified(family, cw, max_by_define, separate_ntf)
2651    elif family.msg_id_model == 'directional':
2652        render_uapi_directional(family, cw, max_by_define)
2653    else:
2654        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')
2655
2656    if separate_ntf:
2657        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
2658        for op in family.msgs.values():
2659            if separate_ntf and not ('notify' in op or 'event' in op):
2660                continue
2661
2662            suffix = ','
2663            if 'value' in op:
2664                suffix = f" = {op['value']},"
2665            cw.p(op.enum_name + suffix)
2666        cw.block_end(line=';')
2667        cw.nl()
2668
2669    # Multicast
2670    defines = []
2671    for grp in family.mcgrps['list']:
2672        name = grp['name']
2673        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
2674                        f'{name}'])
2675    cw.nl()
2676    if defines:
2677        cw.writes_defines(defines)
2678        cw.nl()
2679
2680    cw.p(f'#endif /* {hdr_prot} */')
2681
2682
2683def _render_user_ntf_entry(ri, op):
2684    ri.cw.block_start(line=f"[{op.enum_name}] = ")
2685    ri.cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
2686    ri.cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
2687    ri.cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
2688    ri.cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
2689    ri.cw.block_end(line=',')
2690
2691
2692def render_user_family(family, cw, prototype):
2693    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
2694    if prototype:
2695        cw.p(f'extern {symbol};')
2696        return
2697
2698    if family.ntfs:
2699        cw.block_start(line=f"static const struct ynl_ntf_info {family['name']}_ntf_info[] = ")
2700        for ntf_op_name, ntf_op in family.ntfs.items():
2701            if 'notify' in ntf_op:
2702                op = family.ops[ntf_op['notify']]
2703                ri = RenderInfo(cw, family, "user", op, "notify")
2704            elif 'event' in ntf_op:
2705                ri = RenderInfo(cw, family, "user", ntf_op, "event")
2706            else:
2707                raise Exception('Invalid notification ' + ntf_op_name)
2708            _render_user_ntf_entry(ri, ntf_op)
2709        for op_name, op in family.ops.items():
2710            if 'event' not in op:
2711                continue
2712            ri = RenderInfo(cw, family, "user", op, "event")
2713            _render_user_ntf_entry(ri, op)
2714        cw.block_end(line=";")
2715        cw.nl()
2716
2717    cw.block_start(f'{symbol} = ')
2718    cw.p(f'.name\t\t= "{family.c_name}",')
2719    if family.fixed_header:
2720        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
2721    else:
2722        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
2723    if family.ntfs:
2724        cw.p(f".ntf_info\t= {family['name']}_ntf_info,")
2725        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family['name']}_ntf_info),")
2726    cw.block_end(line=';')
2727
2728
2729def family_contains_bitfield32(family):
2730    for _, attr_set in family.attr_sets.items():
2731        if attr_set.subset_of:
2732            continue
2733        for _, attr in attr_set.items():
2734            if attr.type == "bitfield32":
2735                return True
2736    return False
2737
2738
2739def find_kernel_root(full_path):
2740    sub_path = ''
2741    while True:
2742        sub_path = os.path.join(os.path.basename(full_path), sub_path)
2743        full_path = os.path.dirname(full_path)
2744        maintainers = os.path.join(full_path, "MAINTAINERS")
2745        if os.path.exists(maintainers):
2746            return full_path, sub_path[:-1]
2747
2748
2749def main():
2750    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
2751    parser.add_argument('--mode', dest='mode', type=str, required=True,
2752                        choices=('user', 'kernel', 'uapi'))
2753    parser.add_argument('--spec', dest='spec', type=str, required=True)
2754    parser.add_argument('--header', dest='header', action='store_true', default=None)
2755    parser.add_argument('--source', dest='header', action='store_false')
2756    parser.add_argument('--user-header', nargs='+', default=[])
2757    parser.add_argument('--cmp-out', action='store_true', default=None,
2758                        help='Do not overwrite the output file if the new output is identical to the old')
2759    parser.add_argument('--exclude-op', action='append', default=[])
2760    parser.add_argument('-o', dest='out_file', type=str, default=None)
2761    args = parser.parse_args()
2762
2763    if args.header is None:
2764        parser.error("--header or --source is required")
2765
2766    exclude_ops = [re.compile(expr) for expr in args.exclude_op]
2767
2768    try:
2769        parsed = Family(args.spec, exclude_ops)
2770        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
2771            print('Spec license:', parsed.license)
2772            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
2773            os.sys.exit(1)
2774    except yaml.YAMLError as exc:
2775        print(exc)
2776        os.sys.exit(1)
2777        return
2778
2779    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))
2780
2781    _, spec_kernel = find_kernel_root(args.spec)
2782    if args.mode == 'uapi' or args.header:
2783        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
2784    else:
2785        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
2786    cw.p("/* Do not edit directly, auto-generated from: */")
2787    cw.p(f"/*\t{spec_kernel} */")
2788    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
2789    if args.exclude_op or args.user_header:
2790        line = ''
2791        line += ' --user-header '.join([''] + args.user_header)
2792        line += ' --exclude-op '.join([''] + args.exclude_op)
2793        cw.p(f'/* YNL-ARG{line} */')
2794    cw.nl()
2795
2796    if args.mode == 'uapi':
2797        render_uapi(parsed, cw)
2798        return
2799
2800    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
2801    if args.header:
2802        cw.p('#ifndef ' + hdr_prot)
2803        cw.p('#define ' + hdr_prot)
2804        cw.nl()
2805
2806    if args.out_file:
2807        hdr_file = os.path.basename(args.out_file[:-2]) + ".h"
2808    else:
2809        hdr_file = "generated_header_file.h"
2810
2811    if args.mode == 'kernel':
2812        cw.p('#include <net/netlink.h>')
2813        cw.p('#include <net/genetlink.h>')
2814        cw.nl()
2815        if not args.header:
2816            if args.out_file:
2817                cw.p(f'#include "{hdr_file}"')
2818            cw.nl()
2819        headers = ['uapi/' + parsed.uapi_header]
2820        headers += parsed.kernel_family.get('headers', [])
2821    else:
2822        cw.p('#include <stdlib.h>')
2823        cw.p('#include <string.h>')
2824        if args.header:
2825            cw.p('#include <linux/types.h>')
2826            if family_contains_bitfield32(parsed):
2827                cw.p('#include <linux/netlink.h>')
2828        else:
2829            cw.p(f'#include "{hdr_file}"')
2830            cw.p('#include "ynl.h"')
2831        headers = []
2832    for definition in parsed['definitions']:
2833        if 'header' in definition:
2834            headers.append(definition['header'])
2835    if args.mode == 'user':
2836        headers.append(parsed.uapi_header)
2837    seen_header = []
2838    for one in headers:
2839        if one not in seen_header:
2840            cw.p(f"#include <{one}>")
2841            seen_header.append(one)
2842    cw.nl()
2843
2844    if args.mode == "user":
2845        if not args.header:
2846            cw.p("#include <linux/genetlink.h>")
2847            cw.nl()
2848            for one in args.user_header:
2849                cw.p(f'#include "{one}"')
2850        else:
2851            cw.p('struct ynl_sock;')
2852            cw.nl()
2853            render_user_family(parsed, cw, True)
2854        cw.nl()
2855
2856    if args.mode == "kernel":
2857        if args.header:
2858            for _, struct in sorted(parsed.pure_nested_structs.items()):
2859                if struct.request:
2860                    cw.p('/* Common nested types */')
2861                    break
2862            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2863                if struct.request:
2864                    print_req_policy_fwd(cw, struct)
2865            cw.nl()
2866
2867            if parsed.kernel_policy == 'global':
2868                cw.p(f"/* Global operation policy for {parsed.name} */")
2869
2870                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
2871                print_req_policy_fwd(cw, struct)
2872                cw.nl()
2873
2874            if parsed.kernel_policy in {'per-op', 'split'}:
2875                for op_name, op in parsed.ops.items():
2876                    if 'do' in op and 'event' not in op:
2877                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
2878                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
2879                        cw.nl()
2880
2881            print_kernel_op_table_hdr(parsed, cw)
2882            print_kernel_mcgrp_hdr(parsed, cw)
2883            print_kernel_family_struct_hdr(parsed, cw)
2884        else:
2885            print_kernel_policy_ranges(parsed, cw)
2886
2887            for _, struct in sorted(parsed.pure_nested_structs.items()):
2888                if struct.request:
2889                    cw.p('/* Common nested types */')
2890                    break
2891            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2892                if struct.request:
2893                    print_req_policy(cw, struct)
2894            cw.nl()
2895
2896            if parsed.kernel_policy == 'global':
2897                cw.p(f"/* Global operation policy for {parsed.name} */")
2898
2899                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
2900                print_req_policy(cw, struct)
2901                cw.nl()
2902
2903            for op_name, op in parsed.ops.items():
2904                if parsed.kernel_policy in {'per-op', 'split'}:
2905                    for op_mode in ['do', 'dump']:
2906                        if op_mode in op and 'request' in op[op_mode]:
2907                            cw.p(f"/* {op.enum_name} - {op_mode} */")
2908                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
2909                            print_req_policy(cw, ri.struct['request'], ri=ri)
2910                            cw.nl()
2911
2912            print_kernel_op_table(parsed, cw)
2913            print_kernel_mcgrp_src(parsed, cw)
2914            print_kernel_family_struct_src(parsed, cw)
2915
2916    if args.mode == "user":
2917        if args.header:
2918            cw.p('/* Enums */')
2919            put_op_name_fwd(parsed, cw)
2920
2921            for name, const in parsed.consts.items():
2922                if isinstance(const, EnumSet):
2923                    put_enum_to_str_fwd(parsed, cw, const)
2924            cw.nl()
2925
2926            cw.p('/* Common nested types */')
2927            for attr_set, struct in parsed.pure_nested_structs.items():
2928                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
2929                print_type_full(ri, struct)
2930
2931            for op_name, op in parsed.ops.items():
2932                cw.p(f"/* ============== {op.enum_name} ============== */")
2933
2934                if 'do' in op and 'event' not in op:
2935                    cw.p(f"/* {op.enum_name} - do */")
2936                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
2937                    print_req_type(ri)
2938                    print_req_type_helpers(ri)
2939                    cw.nl()
2940                    print_rsp_type(ri)
2941                    print_rsp_type_helpers(ri)
2942                    cw.nl()
2943                    print_req_prototype(ri)
2944                    cw.nl()
2945
2946                if 'dump' in op:
2947                    cw.p(f"/* {op.enum_name} - dump */")
2948                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
2949                    print_req_type(ri)
2950                    print_req_type_helpers(ri)
2951                    if not ri.type_consistent:
2952                        print_rsp_type(ri)
2953                    print_wrapped_type(ri)
2954                    print_dump_prototype(ri)
2955                    cw.nl()
2956
2957                if op.has_ntf:
2958                    cw.p(f"/* {op.enum_name} - notify */")
2959                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
2960                    if not ri.type_consistent:
2961                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
2962                    print_wrapped_type(ri)
2963
2964            for op_name, op in parsed.ntfs.items():
2965                if 'event' in op:
2966                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
2967                    cw.p(f"/* {op.enum_name} - event */")
2968                    print_rsp_type(ri)
2969                    cw.nl()
2970                    print_wrapped_type(ri)
2971            cw.nl()
2972        else:
2973            cw.p('/* Enums */')
2974            put_op_name(parsed, cw)
2975
2976            for name, const in parsed.consts.items():
2977                if isinstance(const, EnumSet):
2978                    put_enum_to_str(parsed, cw, const)
2979            cw.nl()
2980
2981            has_recursive_nests = False
2982            cw.p('/* Policies */')
2983            for struct in parsed.pure_nested_structs.values():
2984                if struct.recursive:
2985                    put_typol_fwd(cw, struct)
2986                    has_recursive_nests = True
2987            if has_recursive_nests:
2988                cw.nl()
2989            for name in parsed.pure_nested_structs:
2990                struct = Struct(parsed, name)
2991                put_typol(cw, struct)
2992            for name in parsed.root_sets:
2993                struct = Struct(parsed, name)
2994                put_typol(cw, struct)
2995
2996            cw.p('/* Common nested types */')
2997            if has_recursive_nests:
2998                for attr_set, struct in parsed.pure_nested_structs.items():
2999                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3000                    free_rsp_nested_prototype(ri)
3001                    if struct.request:
3002                        put_req_nested_prototype(ri, struct)
3003                    if struct.reply:
3004                        parse_rsp_nested_prototype(ri, struct)
3005                cw.nl()
3006            for attr_set, struct in parsed.pure_nested_structs.items():
3007                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3008
3009                free_rsp_nested(ri, struct)
3010                if struct.request:
3011                    put_req_nested(ri, struct)
3012                if struct.reply:
3013                    parse_rsp_nested(ri, struct)
3014
3015            for op_name, op in parsed.ops.items():
3016                cw.p(f"/* ============== {op.enum_name} ============== */")
3017                if 'do' in op and 'event' not in op:
3018                    cw.p(f"/* {op.enum_name} - do */")
3019                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3020                    print_req_free(ri)
3021                    print_rsp_free(ri)
3022                    parse_rsp_msg(ri)
3023                    print_req(ri)
3024                    cw.nl()
3025
3026                if 'dump' in op:
3027                    cw.p(f"/* {op.enum_name} - dump */")
3028                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
3029                    if not ri.type_consistent:
3030                        parse_rsp_msg(ri, deref=True)
3031                    print_req_free(ri)
3032                    print_dump_type_free(ri)
3033                    print_dump(ri)
3034                    cw.nl()
3035
3036                if op.has_ntf:
3037                    cw.p(f"/* {op.enum_name} - notify */")
3038                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
3039                    if not ri.type_consistent:
3040                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
3041                    print_ntf_type_free(ri)
3042
3043            for op_name, op in parsed.ntfs.items():
3044                if 'event' in op:
3045                    cw.p(f"/* {op.enum_name} - event */")
3046
3047                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3048                    parse_rsp_msg(ri)
3049
3050                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
3051                    print_ntf_type_free(ri)
3052            cw.nl()
3053            render_user_family(parsed, cw, False)
3054
3055    if args.header:
3056        cw.p(f'#endif /* {hdr_prot} */')
3057
3058
3059if __name__ == "__main__":
3060    main()
3061