xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision fe259a1bb26ec78842c975d992331705b0c2c2e8)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17
18
def c_upper(name):
    """Return *name* as an upper-case C identifier ('-' becomes '_')."""
    return '_'.join(name.upper().split('-'))
21
22
def c_lower(name):
    """Return *name* as a lower-case C identifier ('-' becomes '_')."""
    return '_'.join(name.lower().split('-'))
25
26
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value.

    The accepted form is ``<u|s><bits>-<min|max>``, for example ``u8-max``
    (255), ``s16-min`` (-32768) or ``u64-min`` (0).

    Raises ValueError if *name* does not follow that form (previously such
    names were silently mis-parsed, e.g. 'x32-max' became U32_MAX).
    """
    match = re.fullmatch(r'([us])([0-9]+)-(min|max)', name)
    if not match:
        raise ValueError(f'Invalid limit "{name}", expected a form like "u32-max" or "s64-min"')
    sign, width, bound = match.group(1), int(match.group(2)), match.group(3)

    if sign == 'u' and bound == 'min':
        return 0
    # Signed types lose one bit of magnitude to the sign
    if sign == 's':
        width -= 1
    value = (1 << width) - 1
    if sign == 's' and bound == 'min':
        # Two's complement minimum is one below the negated maximum
        value = -value - 1
    return value
40
41
class BaseNlLib:
    """Hooks describing the netlink library the generated C code targets."""

    def get_family_id(self):
        """Return the C expression the generated code uses for the family ID."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
45
46
class Type(SpecAttr):
    """Base class for the C code generator's view of one attribute.

    Subclasses (one per attribute type) override hooks such as
    ``_attr_get``, ``_attr_typol``, ``_setter_lines``, ``_free_lines`` and
    ``_complex_member_type`` to produce the C snippets used for parsing,
    putting, freeing and setting the attribute. Methods here that raise
    "not implemented" are hooks a subclass must provide if it is used.
    """
    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        # Validation checks from the spec; subclasses may add derived entries
        # (TypeScalar adds 'min'/'max'/'range'/'full-range').
        self.checks = attr.get('checks', {})

        # Whether the attr is used in a request / reply; see set_request()/set_reply()
        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            # The family's own top-level attr set renders as just the family name
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # A trailing '_' is added when the nest name also appears in the
            # family's consts (presumably to avoid a C name clash)
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # C member name: suffix C keywords, prefix names starting with a digit
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve():
        # (declared for readers, then deleted so use before resolve() raises AttributeError)
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attr as used in a request; subset attrs also mark their real attr."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr as used in a reply; subset attrs also mark their real attr."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """Return the numeric value of the check named *limit*.

        Ints are returned unchanged. Strings are looked up as family consts
        first; otherwise they are parsed by limit_to_number() (e.g. 'u32-max').
        Returns None when the check is absent and *default* is None.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """Return the check named *limit* rendered as a C expression string.

        Ints become decimal literals with *suffix* appended. Const names become
        upper-case macro names, prefixed with the family name unless the const
        has a 'header' key. Other strings are just upper-cased
        (e.g. 'u32-max' -> 'U32_MAX'). Returns '' when the check is absent and
        *default* is None.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute the attr's C enum name.

        Raises if an attr in a subset set has checks differing from its real attr.
        """
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        # Falsy by default; multi-value subclasses return True
        return None

    def is_scalar(self):
        # NOTE(review): 's8' and 's16' are not in this set — confirm intentional.
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        # Recursive nests only get pointer treatment (see struct_member) when
        # ri.op is falsy
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """How presence is tracked: 'bit' here; subclasses use 'len', 'count' or ''."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the C declaration for this attr inside the _present struct.

        Returns None when the presence type is not *type_filter* (or has no
        member, e.g. 'count'). User space uses the '__u32' spelling.
        """
        if self.presence_type() != type_filter:
            return

        if self.presence_type() == 'bit':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() == 'len':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}_len;"

    def _complex_member_type(self, ri):
        # C type for struct/array-like members; None for simple types
        return None

    def free_needs_iter(self):
        # Whether the free code needs a loop index 'i' declared
        return False

    def _free_lines(self, ri, var, ref):
        """Return C lines freeing this member of *var* (*ref* is the nest path prefix)."""
        if self.is_multi_val() or self.presence_type() == 'len':
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        lines = self._free_lines(ri, var, ref)
        for line in lines:
            ri.cw.p(line)

    def arg_member(self, ri):
        """Return the C parameter declarations used to pass this attr to a setter.

        Complex members are passed by pointer, with an element count when the
        presence type is 'count'. Raises if the subclass provides neither.
        """
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the member declaration(s) for this attr in its C struct."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            # Arrays and recursive nests are stored by pointer
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attr's entry in the kernel-side policy table."""
        policy = f'NLA_{c_upper(self.type)}'
        # Big-endian u16/u32 use the dedicated NLA_BE16/NLA_BE32 types
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attr's entry in the user-space (ynl) policy table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        # Guard the put with the attr's presence (bit or length); 'count' and
        # '' presence types emit the line unconditionally
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        # Hook: returns (lines, init_lines, local_vars); the first two may be a str
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit this attr's branch of the parse loop; *first* selects 'if' vs 'else if'."""
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if type(lines) is str:
            lines = [lines]
        if type(init_lines) is str:
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        # Multi-value attrs skip validation/presence here (their _attr_get handles entries)
        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a 'static inline void <prefix>_set_<path>()' C helper for this attr.

        *ref* is the list of enclosing nest member names. The helper sets the
        presence bits along the path, frees any previous value and stores the
        new one.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        for i in range(0, len(ref)):
            # Presence of layer i, e.g. "req->_present.a", "req->a._present.b"
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{ref[i]}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == len(ref) - 1 and self.presence_type() != 'bit':
                continue
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        # Setters that free but never allocate get a '__' prefix
        # (presumably marking them as internal/unsafe to call directly)
        free = bool([x for x in code if 'free(' in x])
        alloc = bool([x for x in code if 'alloc(' in x])
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
299
300
class TypeUnused(Type):
    """Reserved/unused attribute slot.

    No struct member, kernel policy, put code, parse branch or setter is
    generated; the user-space policy entry is YNL_PT_REJECT.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        reject = ['return YNL_PARSE_CB_ERROR;']
        return reject, None, None

    def attr_policy(self, cw):
        """No kernel policy entry for unused attrs."""

    def attr_put(self, ri, var):
        """Unused attrs are never emitted."""

    def attr_get(self, ri, var, first):
        """No parse branch for unused attrs."""

    def setter(self, ri, space, direction, deref=False, ref=None):
        """No setter for unused attrs."""
324        pass
325
326
class TypePad(Type):
    """Padding attribute.

    Generates no struct member, kernel policy, put, parse or setter code; the
    user-space policy entry is YNL_PT_IGNORE.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def attr_policy(self, cw):
        """No kernel policy entry for padding."""

    def attr_put(self, ri, var):
        """Padding is never emitted."""

    def attr_get(self, ri, var, first):
        """No parse branch for padding."""

    def setter(self, ri, space, direction, deref=False, ref=None):
        """No setter for padding."""
347        pass
348
349
class TypeScalar(Type):
    """Integer attribute (fixed-width, or 64-bit when is_auto_scalar).

    Derives range checks from an associated enum, and picks the kernel
    policy macro (mask / full range / range / min / max) from the checks.
    """
    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Appended to the struct member / setter arg declaration (see arg_member)
        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Default the min/max checks to the enum's value range unless the spec
        # sets them; min is only added when it is non-zero or the type is
        # signed (presumably 0 is already implied for unsigned types)
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if 'min' not in self.checks:
                if low != 0 or self.type[0] == 's':
                    self.checks['min'] = low
            if 'max' not in self.checks:
                self.checks['max'] = high

        # Both bounds present -> use a range policy; reject inverted ranges
        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside the s16 range need the "full range" policy form
        # (presumably because the inline policy min/max fields are s16)
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Resolve the bitfield flag and the C type name used in structs/setters."""
        self.resolve_up(super())

        # Bitfield if explicitly requested, or if the enum is of 'flags' type
        if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']:
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        # Non-bitfield enums use the enum's C type; auto scalars are 64-bit
        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _attr_policy(self, policy):
        # Priority: flag mask > full range > range > min > max > plain type
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
                mask = enum.get_mask(as_flags=True)
            else:
                # Assumes the flags enum occupies the low flag_cnt bits — TODO confirm
                flags = self.family.consts[self.checks['flags-mask']]
                flag_cnt = len(flags['entries'])
                mask = (1 << flag_cnt) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed and unsigned types alike map to YNL_PT_U<width>
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
436
437
class TypeFlag(Type):
    """Zero-length flag attribute: only its presence bit carries information."""

    def arg_member(self, ri):
        # Setting a flag needs no value argument
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        empty_put = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, empty_put)

    def _attr_get(self, ri, var):
        # Presence is recorded by the generic attr_get(); nothing else to read
        return [], None, None

    def _setter_lines(self, ri, member, presence):
        return []
453
454
class TypeString(Type):
    """String attribute, stored as a heap copy with its length as presence."""

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        len_part = ''
        if 'max-len' in self.checks:
            len_part = ', .len = ' + self.get_limit_str('max-len')
        return '{ .type = ' + policy + len_part + ', }'

    def attr_policy(self, cw):
        # 'unterminated-ok' relaxes the kernel policy to NLA_STRING
        unterminated = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        # Copy at most the attr payload length and always NUL-terminate
        get_lines = [f"{len_mem} = len;",
                     f"{dst} = malloc(len + 1);",
                     f"memcpy({dst}, ynl_attr_get_str(attr), len);",
                     f"{dst}[len] = 0;"]
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"{length} = strlen({self.c_name});",
                f"{member} = malloc({length} + 1);",
                f'memcpy({member}, {self.c_name}, {length});',
                f'{member}[{length}] = 0;']
504
505
class TypeBinary(Type):
    """Opaque binary attribute, stored as a heap copy with its length as presence."""

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        # No trailing space here, unlike most other types; kept byte-identical
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        # At most one length check is supported
        if len(self.checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        if not self.checks:
            return '{ .type = NLA_BINARY, }'

        (check_name,) = self.checks
        if check_name not in {'exact-len', 'min-len', 'max-len'}:
            raise Exception('Unsupported check for binary type: ' + check_name)

        limit = self.get_limit_str(check_name)
        if check_name == 'exact-len':
            return 'NLA_POLICY_EXACT_LEN(' + limit + ')'
        if check_name == 'min-len':
            return '{ .len = ' + limit + ', }'
        return 'NLA_POLICY_MAX_LEN(' + limit + ')'

    def attr_put(self, ri, var):
        data = f"{var}->{self.c_name}"
        size = f"{var}->_present.{self.c_name}_len"
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, {data}, {size})")

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        get_lines = [f"{len_mem} = len;",
                     f"{dst} = malloc(len);",
                     f"memcpy({dst}, ynl_attr_data(attr), len);"]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"{length} = len;",
                f"{member} = malloc({length});",
                f'memcpy({member}, {self.c_name}, {length});']
556
557
class TypeBitfield32(Type):
    """NLA_BITFIELD32 attribute, stored as a struct nla_bitfield32 member."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        # The allowed bits come from the associated flags enum
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        allowed = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({allowed})"

    def attr_put(self, ri, var):
        size = "sizeof(struct nla_bitfield32)"
        put = f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, {size})"
        self._attr_put_line(ri, var, put)

    def _attr_get(self, ri, var):
        copy = f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));"
        return copy, None, None

    def _setter_lines(self, ri, member, presence):
        # The setter receives a pointer (see arg_member), so copy from it
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
581
582
class TypeNest(Type):
    """Attribute containing a nested attribute set, stored as its C struct."""

    def is_recursive(self):
        nested = self.family.pure_nested_structs[self.nested_attrs]
        return nested.recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        target = f'{var}->{ref}{self.c_name}'
        free_fn = f'{self.nested_render_name}_free'
        if not self.is_recursive_for_op(ri):
            return [f'{free_fn}(&{target});']
        # Recursive nests are held by pointer, so guard against NULL
        return [f'if ({target})', f'{free_fn}({target});']

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        target = f'{var}->{self.c_name}'
        if not self.is_recursive_for_op(ri):
            target = '&' + target
        self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, {self.enum_name}, {target})")

    def _attr_get(self, ri, var):
        parse_call = f"{self.nested_render_name}_parse(&parg, attr)"
        get_lines = [f"if ({parse_call})", "return YNL_PARSE_CB_ERROR;"]
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        # Nests get no setter of their own; emit one per non-recursive member
        path = [*(ref or []), self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested.member_list():
            if member.is_recursive():
                continue
            member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
624
625
class TypeMultiAttr(Type):
    """Attribute that may repeat within one message (multi-attr).

    Values are collected into a C array with an ``n_<name>`` counter instead
    of a presence bit. Kernel/user policy generation is delegated to
    *base_type*, the type object describing a single instance.
    """
    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def _is_nest(self):
        # Single source of truth for "elements are nests" (a missing 'type'
        # also means nest). Checked first everywhere so that no method indexes
        # self.attr['type'] before confirming the key exists.
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if self._is_nest():
            return self.nested_struct_type
        elif self.attr['type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        else:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        # Nested elements are freed one by one in a loop over 'i'
        return self._is_nest()

    def _free_lines(self, ri, var, ref):
        if self._is_nest():
            return [
                f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)",
                f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);',
                f"free({var}->{ref}{self.c_name});",
            ]
        elif self.attr['type'] in scalars:
            return [f"free({var}->{ref}{self.c_name});"]
        else:
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only count occurrences here; the values are collected separately
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        if self._is_nest():
            ri.cw.p(f"for (i = 0; i < {var}->n_{self.c_name}; i++)")
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        elif self.attr['type'] in scalars:
            put_type = self.type
            ri.cw.p(f"for (i = 0; i < {var}->n_{self.c_name}; i++)")
            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # For multi-attr we have a count, not presence, hack up the presence:
        # turn "<path>_present.<name>" into "<path>n_<name>"
        presence = presence[:-(len('_present.') + len(self.c_name))] + "n_" + self.c_name
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
690
691
class TypeArrayNest(Type):
    """Legacy "indexed array" attribute: a nest whose entries are array elements.

    The 'sub-type' spec key selects the element type; when it is absent the
    elements are nests.
    """
    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
            return self.nested_struct_type
        elif self.attr['sub-type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['sub-type']
        else:
            raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")

    def _attr_typol(self):
        # Use .get(): a missing 'sub-type' means nested elements (as handled
        # in _complex_member_type) and must not raise KeyError here.
        sub_type = self.attr.get('sub-type')
        if sub_type in scalars:
            return f'.type = YNL_PT_U{c_upper(sub_type[1:])}, '
        else:
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remember the array attr in attr_<name> (presumably declared by the
        # caller — TODO confirm), then validate and count each element.
        local_vars = ['const struct nlattr *attr2;']
        get_lines = [f'attr_{self.c_name} = attr;',
                     'ynl_attr_for_each_nested(attr2, attr) {',
                     '\tif (ynl_attr_validate(yarg, attr2))',
                     '\t\treturn YNL_PARSE_CB_ERROR;',
                     f'\t{var}->n_{self.c_name}++;',
                     '}']
        return get_lines, None, local_vars
723
724
class TypeNestTypeValue(Type):
    """Nest where, for each level listed in 'type-value', the nested
    attribute's type number is read out and passed to the parse helper."""

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        local_vars = []
        get_lines = []
        src = 'attr'
        extra_args = ''
        if 'type-value' in self.attr:
            names = [c_lower(level) for level in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(names)};')
            local_vars.append(f'__u32 {", ".join(names)};')
            # Descend one level per name, recording each level's attr type
            for name in names:
                get_lines.append(f'attr_{name} = ynl_attr_data({src});')
                get_lines.append(f'{name} = ynl_attr_type(attr_{name});')
                src = 'attr_' + name
            extra_args = ', ' + ', '.join(names)

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {src}{extra_args});")
        return get_lines, init_lines, local_vars
753
754
class Struct:
    """C struct rendered for an attribute set, or for a subset of its attrs.

    Args:
        family: the spec family.
        space_name: name of the attr set the struct is built from.
        type_list: attr names to include; None means every attr in the set
            and marks the struct as a pure nested struct (``self.nested``).
        inherited: attr names expected from the parent (see set_inherited()).
    """
    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        render_src = family.ident_name
        if family.name != c_lower(space_name):
            render_src += '-' + space_name
        self.render_name = c_lower(render_src)
        # Same '_' suffix rule as nested_struct_type for names also in consts
        suffix = '_' if self.nested and space_name in family.consts else ''
        self.struct_name = 'struct ' + self.render_name + suffix
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        self.in_multi_val = False  # used by MultiAttrs and legacy arrays

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in names]
        self.attrs = {}

        # Track the attr with the highest value (on ties the later one wins)
        self.attr_max_val = None
        max_val = 0
        for name, attr in self.attr_list:
            if attr.value >= max_val:
                max_val = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        return iter(self.attrs)

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in struct order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record inherited members (as sorted C names); they must match the constructor's."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(name) for name in sorted(self._inherited)]
811
812
class EnumEntry(SpecEnumEntry):
    """Enum entry extended with its C name and a value_change flag."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # True when the value does not just follow the previous entry (or 0
        # for the first one), and always for flags; presumably used to decide
        # whether the C enum needs an explicit "= value".
        implicit_value = prev.value + 1 if prev else 0
        self.value_change = (self.value != implicit_value
                             or self.enum_set['type'] == 'flags')

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
831
832
class EnumSet(SpecEnumSet):
    """Enum/flags definition with the C names the generator needs."""

    def __init__(self, family, yaml):
        # Naming attrs must exist before super().__init__() builds the entries
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        # An empty 'enum-name' means an anonymous enum (no C enum type)
        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None

        self.user_type = self.enum_name if self.enum_name else 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (low, high) of the entry values; raises if they have gaps."""
        values = [entry.value for entry in self.entries.values()]
        low, high = min(values), max(values)

        if high - low + 1 != len(self.entries):
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
868
869
class AttrSet(SpecAttrSet):
    """Attribute set extended with C naming (enum prefix, max / count names)."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # A subset reuses the enum naming of the set it is taken from
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            else:
                # The set named after the family gets no extra name component
                own = '' if self.name == family.name else f"{self.name}-"
                pfx = f"{family.ident_name}-a-{own}"
            self.name_prefix = c_upper(pfx)
            default_max = f"{self.name_prefix}max"
            default_cnt = f"__{self.name_prefix}max"
            self.max_name = c_upper(self.yaml.get('attr-max-name', default_max))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', default_cnt))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        name = c_lower(self.name)
        # Avoid colliding with C keywords
        if name in _C_KW:
            name += '_'
        # The set named like the family gets an empty C name
        if name == self.family.c_name:
            name = ''
        self.c_name = name

    def new_attr(self, elem, value):
        """Build the Type subclass instance for attribute spec *elem*.

        Multi-attr attributes are wrapped in TypeMultiAttr.
        Raises Exception for unsupported types or indexed-array sub-types.
        """
        attr_type = elem['type']
        simple_types = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'binary': TypeBinary,
            'bitfield32': TypeBitfield32,
            'nest': TypeNest,
            'nest-type-value': TypeNestTypeValue,
        }

        if attr_type in scalars:
            cls = TypeScalar
        elif attr_type in simple_types:
            cls = simple_types[attr_type]
        elif attr_type == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            cls = TypeArrayNest
        else:
            raise Exception(f"No typed class for type {elem['type']}")

        t = cls(self.family, self, elem, value)
        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)
        return t
931
932
class Operation(SpecOperation):
    """Operation extended with C naming and rendering flags.

    dual_policy is True when both 'do' and 'dump' carry a request.
    has_ntf is set via mark_has_ntf(). enum_name is assigned by resolve().
    """

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(f"{family.ident_name}_{self.name}")

        def has_request(mode):
            return mode in yaml and 'request' in yaml[mode]

        self.dual_policy = has_request('do') and has_request('dump')

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        # Async ops (notifications / events) may use a separate prefix
        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        self.has_ntf = True
958
959
class Family(SpecFamily):
    """Spec family extended with the C naming and derived state used by the
    C code generator.

    resolve() does the following:
    - computes the op enum prefixes;
    - collects the request/reply members of every root attribute set;
    - discovers the nested ("pure") structs reachable from those sets;
    - propagates the request/reply/recursive/multi-val flags to those structs;
    - gathers the pre/post hooks and, for kernel-policy 'global', the shared
      request policy.
    """
    def __init__(self, file_name, exclude_ops):
        # Added by resolve:
        # Each attribute is declared and deleted again, so it only exists once
        # resolve() assigns it (the declaration presumably serves as a list
        # for readers).
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # "linux/<name>.h" -> "<name>"
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] -> {'set': dedup set, 'list': first-seen order}
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct for nested attribute sets
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def is_classic(self):
        return self.proto == 'netlink-raw'

    def _mark_notify(self):
        """Flag every op that some other message names in its 'notify' key."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Collect, per attribute set used by a message, the attributes used
        in requests and in replies (events count as replies)."""
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs on a best-effort basis, so that each
        struct comes after the non-recursive nests it contains. The work is
        bounded by len(structs)**2 rounds."""
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                    if self.pure_nested_structs[nested].recursive:
                        continue
                    if nested not in pns_key_seen:
                        # Dicts are sorted, this will make struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_sets(self):
        """Discover the nested attribute sets reachable from the root sets.

        Each nested set gets a Struct in pure_nested_structs. The request,
        reply and multi-val marks of root attributes are then propagated
        down transitively, recursive nests are detected, and the structs are
        reordered so that dependencies come first.

        Raises Exception if an attribute set is used both as a root and as a
        nest, or if the users of one nest disagree on inherited members.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        # Breadth-first walk over 'nested-attributes' references
        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for _, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                if 'type-value' in spec:
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'indexed-array':
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                # Only nests have a struct to mark. For any other attribute
                # (e.g. a multi-attr scalar), 'nested' would be unbound or
                # left over from a previous iteration.
                if 'nested-attributes' not in spec:
                    continue
                nested = spec['nested-attributes']
                if attr in rs_members['request']:
                    self.pure_nested_structs[nested].request = True
                if attr in rs_members['reply']:
                    self.pure_nested_structs[nested].reply = True
                if spec.is_multi_val():
                    self.pure_nested_structs[nested].in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                    struct.child_nests.add(child_name)
                    child = self.pure_nested_structs.get(child_name)
                    if child:
                        if not child.recursive:
                            struct.child_nests.update(child.child_nests)
                        child.request |= struct.request
                        child.reply |= struct.reply
                        if spec.is_multi_val():
                            child.in_multi_val = True
                if attr_set in struct.child_nests:
                    struct.recursive = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Mark attributes as used in requests and/or replies (via set_request()
        and set_reply()), for nested structs and for root sets."""
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_global_policy(self):
        """Build global_policy: the request attributes of every op's do/dump,
        listed in attribute-set order.

        Raises Exception if ops use different attribute sets.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Record each op's pre/post hook names once, in first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1227
1228
class RenderInfo:
    """Context for rendering one op / mode (or one bare attribute set).

    Holds the writer, the op, the fixed-header struct name, the type naming
    flags used by op_prefix(), and the request / reply Struct objects.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        self.fixed_hdr = None
        if op and op.fixed_header:
            self.fixed_hdr = f'struct {c_lower(op.fixed_header)}'

        # 'do' and 'dump' response parsing is identical unless shown otherwise
        self.type_consistent = True
        self.type_oneside = False
        if op_mode != 'do' and 'dump' in op:
            if 'do' not in op:
                # Only a dump exists, so there is nothing to be inconsistent with
                self.type_oneside = True
            else:
                do_has_reply = 'reply' in op['do']
                dump_has_reply = 'reply' in op['dump']
                if do_has_reply != dump_has_reply:
                    self.type_consistent = False
                elif do_has_reply and op['do']['reply'] != op['dump']['reply']:
                    self.type_consistent = False

        self.attr_set = attr_set if attr_set else op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            self.type_name_conflict = attr_set in family.consts

        self.cw = cw

        # Notifications are rendered from the op's 'do' spec
        if op_mode == 'notify':
            op_mode = 'do'
        self.struct = {}
        if op:
            mode_spec = op[op_mode]
            for op_dir in ('request', 'reply'):
                type_list = mode_spec[op_dir]['attributes'] if op_dir in mode_spec else []
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])

    def type_empty(self, key):
        """True when struct *key* has no attributes and there is no fixed header."""
        no_attrs = len(self.struct[key].attr_list) == 0
        return no_attrs and self.fixed_hdr is None
1282
1283
class CodeWriter:
    """Indentation- and brace-aware emitter of C source text.

    Output goes to stdout or, when *out_file* is given, to a temporary
    buffer that close_out_file() copies to *out_file*. With overwrite=False,
    an existing *out_file* whose contents already match is left untouched.
    """
    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False
        self._block_end = False
        self._silent_block = False
        self._ind = 0
        self._ifdef_block = None
        if out_file is None:
            self._out = sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        # __init__ may have failed before _out was assigned
        if hasattr(self, '_out'):
            self.close_out_file()

    def close_out_file(self):
        """Copy buffered output to the destination file and stop buffering.

        This is a no-op when writing to stdout or after a previous call.
        The temporary buffer is always closed and later output goes to
        stdout, including when the destination is left untouched because its
        contents already matched.
        """
        if self._out is sys.stdout:
            return
        self._out.flush()
        # Avoid modifying the file if contents didn't change
        unchanged = (not self._overwrite and os.path.isfile(self._out_file)
                     and filecmp.cmp(self._out.name, self._out_file, shallow=False))
        if not unchanged:
            with open(self._out_file, 'w+') as out_file:
                self._out.seek(0)
                shutil.copyfileobj(self._out, out_file)
        # Close in both cases; the early return used to leave the buffer open
        self._out.close()
        self._out = sys.stdout

    @classmethod
    def _is_cond(cls, line):
        return line.startswith('if') or line.startswith('while') or line.startswith('for')

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        Before the line, p() emits any pending '}' (merged as '} else' when
        the line starts with 'else') and any pending blank line. Labels
        ending in ':' are outdented one level. The body of a brace-less
        if/while/for gets one extra level. Lines starting with '#' go to
        column 0. add_ind adds extra levels.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line[-1] == ':':
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if line[0] == '#':
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next written line."""
        self._nl = True

    def block_start(self, line=''):
        """Open a '{' block, optionally preceded by *line*."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close a block.

        With no *line*, the '}' is deferred so that a following 'else' can be
        merged onto it. Otherwise the brace is written right away, followed by
        *line* (no space before ';' or ',').
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Word-wrap *doc* into ' *' comment lines kept under 79 columns."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype, wrapping arguments when it does not
        fit in 80 columns. No *args* means '(void)'."""
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        # Continuation lines align with the opening parenthesis (tabs + spaces)
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations (longest first), then a blank line.

        Note: a list argument is sorted in place.
        """
        if not local_vars:
            return

        if type(local_vars) is str:
            local_vars = [local_vars]

        local_vars.sort(key=len, reverse=True)
        for var in local_vars:
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write (name, value) pairs as tab-aligned #define lines; str values
        are quoted, int values are written as-is."""
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write '.name = value,' designated initializers, tab-aligned."""
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch the active '#ifdef CONFIG_<config>' block (None/'' closes it)."""
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1479
1480
# Attribute types handled as plain integers: u8..u64, s8..s64, uint, sint
scalars = {f'{sign}{bits}' for sign in 'us' for bits in (8, 16, 32, 64)} | {'uint', 'sint'}

# Struct-name suffix for each message direction
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': '',
}

# Struct-name suffix for each operation mode
op_mode_to_wrapper = dict(
    do='',
    dump='_list',
    notify='_ntf',
    event='',
)

# C keywords; AttrSet.resolve() appends '_' to set names that collide with one
_C_KW = set("""
    auto bool break case char const continue default do double else enum
    extern float for goto if inline int long register return short signed
    sizeof static struct switch typedef union unsigned void volatile while
""".split())
1532
1533
def rdir(direction):
    """Return the opposite of 'request' / 'reply'; other values pass through."""
    if direction not in ('reply', 'request'):
        return direction
    return 'request' if direction == 'reply' else 'reply'
1540
1541
def op_prefix(ri, direction, deref=False):
    """Build the C identifier prefix '<family>_<type>[suffixes]' for *ri*.

    For 'do' (or no mode), the suffix is the one for *direction*. Otherwise,
    requests get '_req' (plus '_dump' unless only a dump exists). Replies get
    the per-mode wrapper suffix when do/dump replies match, or
    '_rsp_dump' / '_rsp_list' when they differ. *deref* selects the
    element type rather than the wrapper.
    """
    parts = [f"_{ri.type_name}"]

    if not ri.op_mode or ri.op_mode == 'do':
        parts.append(direction_to_suffix[direction])
    elif direction == 'request':
        parts.append('_req')
        if not ri.type_oneside:
            parts.append('_dump')
    elif not ri.type_consistent:
        parts.append('_rsp')
        parts.append('_dump' if deref else '_list')
    elif deref:
        parts.append(direction_to_suffix[direction])
    else:
        parts.append(op_mode_to_wrapper[ri.op_mode])

    return f"{ri.family.c_name}{''.join(parts)}"
1563
1564
def type_name(ri, direction, deref=False):
    """Return the C struct type ('struct <op_prefix>') for *ri* / *direction*."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1567
1568
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the user call for ri.op in ri.op_mode.

    A request struct pointer is taken when the mode has a 'request'. A reply
    struct pointer is returned when it has a 'reply', otherwise int. Dump
    calls get a '_dump' name suffix. With terminate=True the prototype ends
    in ';'.
    """
    suffix = ';' if terminate else ''
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')

    mode_spec = ri.op[ri.op_mode]
    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        # Parameter is named after the direction suffix without the '_'
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{arg_name}")

    ret = f"{type_name(ri, rdir(direction))} *" if 'reply' in mode_spec else 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=suffix)
1585
1586
def print_req_prototype(ri):
    """Emit the request prototype for *ri*, with the op's 'doc' as a comment."""
    op_doc = ri.op['doc']
    print_prototype(ri, direction="request", doc=op_doc)
1589
1590
def print_dump_prototype(ri):
    """Emit the request prototype for *ri* without a doc comment."""
    print_prototype(ri, direction="request")
1593
1594
def put_typol_fwd(cw, struct):
    """Emit the extern declaration of the struct's ynl_policy_nest."""
    nest_name = struct.render_name + '_nest'
    cw.p(f'extern const struct ynl_policy_nest {nest_name};')
1597
1598
def put_typol(cw, struct):
    """Emit the ynl_policy_attr table for *struct* and its ynl_policy_nest."""
    type_max = struct.attr_set.max_name
    name = struct.render_name

    cw.block_start(line=f'const struct ynl_policy_attr {name}_policy[{type_max} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {name}_nest =')
    for field in (f'.max_attr = {type_max},', f'.table = {name}_policy,'):
        cw.p(field)
    cw.block_end(line=';')
    cw.nl()
1614
1615
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit 'const char *<render_name>_str(<type> <arg_name>)'.

    The function indexes *map_name* and returns NULL for out-of-range values.
    For flags enums the argument is first converted to a bit index with
    ffs() - 1. The argument type is enum.user_type, or int without *enum*.
    """
    arg_type = enum.user_type if enum else 'int'
    cw.write_func_prot('const char *', f'{render_name}_str', [f'{arg_type} {arg_name}'])
    cw.block_start()

    body = []
    if enum and enum.type == 'flags':
        body.append(f'{arg_name} = ffs({arg_name}) - 1;')
    body.append(f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))')
    body.append('return NULL;')
    body.append(f'return {map_name}[{arg_name}];')
    for line in body:
        cw.p(line)

    cw.block_end()
    cw.nl()
1629
1630
def put_op_name_fwd(family, cw):
    """Emit the prototype of the <family>_op_str() lookup function."""
    fname = f'{family.c_name}_op_str'
    cw.write_func_prot('const char *', fname, ['int op'], suffix=';')
1633
1634
def put_op_name(family, cw):
    """Emit <family>_op_strmap (reply value -> op name) and <family>_op_str().

    Only messages with a reply value are listed. A message is indexed by its
    enum constant when its request and reply values match, otherwise by the
    numeric reply value.
    """
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
1654
1655
def put_enum_to_str_fwd(family, cw, enum):
    """Emit the prototype of <enum>_str(); *family* is unused."""
    cw.write_func_prot('const char *', f'{enum.render_name}_str',
                       [f'{enum.user_type} value'], suffix=';')
1659
1660
def put_enum_to_str(family, cw, enum):
    """Emit <enum>_strmap (indexed by entry value) and its <enum>_str() helper."""
    map_name = f'{enum.render_name}_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    entry_lines = [f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values()]
    for line in entry_lines:
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1670
1671
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Emit 'int <struct>_put(nlh, attr_type, obj)' followed by *suffix*."""
    fname = f'{struct.render_name}_put'
    params = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]
    ri.cw.write_func_prot('int', fname, params, suffix=suffix)
1679
1680
def put_req_nested(ri, struct):
    """Emit the <struct>_put() body.

    The body opens a nest of attr_type, writes every member attribute from
    'obj', closes the nest and returns 0. A loop counter 'i' is declared when
    any member has 'count' presence.
    """
    cw = ri.cw

    local_vars = ['struct nlattr *nest;']
    if any(arg.presence_type() == 'count' for _, arg in struct.member_list()):
        local_vars.append('unsigned int i;')

    put_req_nested_prototype(ri, struct, suffix='')
    cw.block_start()
    cw.write_func_lvar(local_vars)

    cw.p("nest = ynl_attr_nest_start(nlh, attr_type);")

    for _, arg in struct.member_list():
        arg.attr_put(ri, "obj")

    cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1709
1710
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the C function body that parses the attributes of *struct*.

    Shared by parse_rsp_nested() (nested structs) and parse_rsp_msg()
    (top-level replies); the caller has already printed the prototype.

    Args:
        ri: render info; the body is printed through ri.cw.
        struct: the struct whose members are parsed.
        init_lines: C statements printed right after the local variable
            declarations. Extended in place.
        local_vars: C local variable declarations. Extended in place.

    Indexed-array and multi-attr members are handled in two passes: the main
    attribute loop (built from each member's attr_get() output) runs first,
    then one extra loop per such member calloc()s dst->NAME with n_NAME
    entries and fills it. NOTE(review): the n_NAME counters are presumably
    incremented by the members' attr_get() output, which is not visible
    here - TODO confirm.

    Raises:
        Exception: for an unsupported indexed-array sub-type, or an
            array/multi-attr element type that cannot be parsed.
    """
    if struct.nested:
        # Nested structs iterate the attributes inside 'nested'
        iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        if ri.fixed_hdr:
            local_vars += ['void *hdr;']
        # Top-level messages iterate the message attributes, starting after
        # yarg->ys->family->hdr_len
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"

    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] == 'nest':
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        # Nested members are parsed by calling their own _parse() helper,
        # which takes a ynl_parse_arg
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # One element counter per array / multi-attr member (sorted for stable output)
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    # NOTE(review): this copy uses 'hdr' and 'nlh', which are only declared
    # for non-nested structs; presumably ri.fixed_hdr is never set when
    # rendering nested structs - confirm.
    if ri.fixed_hdr:
        if ri.family.is_classic():
            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
        else:
            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
    # Refuse to parse into a dst whose array members are already set
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for indexed arrays: allocate and fill from attr_<name>
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attr members: re-walk the attributes and pick
    # out the ones of this member's type
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return 0; top-level parsers return a parse-callback code
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1834
1835
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Emit the C prototype of the parser for nested struct *struct*.

    Besides the parse argument and the nest attribute, the parser takes one
    ``__u32`` parameter per name listed in ``struct.inherited``.
    """
    params = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    params.extend(f'__u32 {name}' for name in struct.inherited)

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', params,
                          suffix=suffix)
1844
1845
def parse_rsp_nested(ri, struct):
    """Emit the C parser function for nested struct *struct*.

    Structs with members get a full body from _multi_parse(); an empty nest
    gets a trivial body that just returns 0.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    decls = ['const struct nlattr *attr;',
             f'{struct.ptr_name}dst = yarg->data;']

    if not struct.member_list():
        # Empty nest, nothing to parse
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, struct, [], decls)
1861
1862
def parse_rsp_msg(ri, deref=False):
    """Emit the C parse callback for the reply of the current op.

    Nothing is emitted unless the current op mode declares a 'reply' or is
    an event. A reply struct without members gets a body that only returns
    YNL_PARSE_CB_OK.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    params = ['const struct nlmsghdr *nlh', 'struct ynl_parse_arg *yarg']
    dst_type = type_name(ri, "reply", deref=deref)
    decls = [f'{dst_type} *dst;', 'const struct nlattr *attr;']

    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse', params)

    reply = ri.struct["reply"]
    if reply.member_list():
        _multi_parse(ri, reply, ['dst = yarg->data;'], decls)
        return

    # Empty reply
    ri.cw.block_start()
    ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1884
1885
def print_req(ri):
    """Emit the C implementation of a 'do' request for the current op.

    The generated function starts the request message (classic netlink when
    ri.family.is_classic(), generic netlink otherwise), sets
    ys->req_policy, copies req->_hdr when the op has a fixed header, puts
    every request attribute and runs ynl_exec().

    If the current op mode declares a 'reply', the generated function
    allocates a reply struct, sets yrs.cb to the reply parser and returns
    the reply, or frees it and returns NULL on error. Without a reply it
    returns 0 on success and -1 on error.
    """
    # Return expressions for the generated C; replaced by the reply pointer
    # and NULL below when the op mode has a reply.
    ret_ok = '0'
    ret_err = '-1'
    direction = "request"
    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if 'reply' in ri.op[ri.op_mode]:
        ret_ok = 'rsp'
        ret_err = 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;']

    if ri.fixed_hdr:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    # Members with 'count' presence need a loop counter in the generated code
    for _, attr in ri.struct["request"].member_list():
        if attr.presence_type() == 'count':
            local_vars += ['unsigned int i;']
            break

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name});")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()

    if ri.fixed_hdr:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    if 'reply' in ri.op[ri.op_mode]:
        # NOTE(review): the generated calloc() result is not NULL-checked
        # before being stored in yrs.yarg.data; confirm this is intended.
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        # Expected reply command: the op's own enum when it has a value,
        # otherwise its rsp_value
        if ri.op.value is not None:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
        else:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
        ri.cw.nl()
    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
    ri.cw.p('if (err < 0)')
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('goto err_free;')
    else:
        ri.cw.p('return -1;')
    ri.cw.nl()

    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()

    if 'reply' in ri.op[ri.op_mode]:
        # Error path: free the reply and return NULL
        ri.cw.p('err_free:')
        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        ri.cw.p(f"return {ret_err};")

    ri.cw.block_end()
1958
1959
def print_dump(ri):
    """Emit the C implementation of a dump request for the current op.

    The generated function fills a ynl_dump_state (reply policy, alloc_sz
    set to the size of the reply type, reply parse callback and expected
    reply command), starts a dump message, adds req->_hdr when the op has a
    fixed header, puts the request attributes when the op mode declares a
    'request', and runs ynl_exec_dump(). It returns yds.first on success;
    on error it frees yds.first and returns NULL.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if ri.fixed_hdr:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    ri.cw.write_func_lvar(local_vars)

    ri.cw.p('yds.yarg.ys = ys;')
    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.p("yds.yarg.data = NULL;")
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    # Expected reply command: the op's own enum when it has a value,
    # otherwise its rsp_value
    if ri.op.value is not None:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
    else:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
    ri.cw.nl()
    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if ri.fixed_hdr:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    # Dumps may have no request attributes at all
    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
2013
2014
def call_free(ri, direction, var):
    """Return the C statement that frees *var* with the type's free helper."""
    free_fn = f"{op_prefix(ri, direction)}_free"
    return f"{free_fn}({var});"
2017
2018
def free_arg_name(direction):
    """Return the parameter name used by a generated free function.

    Direction-less (nested) types use 'obj'. Directional types use their
    direction_to_suffix entry with its first character dropped.
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
2023
2024
def print_alloc_wrapper(ri, direction):
    """Emit a static inline <prefix>_alloc() returning a calloc()ed struct."""
    prefix = op_prefix(ri, direction)
    ret_type = f'static inline struct {prefix} *'

    ri.cw.write_func_prot(ret_type, f"{prefix}_alloc", ["void"])
    ri.cw.block_start()
    ri.cw.p(f'return calloc(1, sizeof(struct {prefix}));')
    ri.cw.block_end()
2031
2032
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of <prefix>_free(), which takes a struct pointer.

    The struct tag gets an extra '_' when ri.type_name_conflict is set; the
    function name itself never does.
    """
    fn_prefix = op_prefix(ri, direction)
    tag = f"{fn_prefix}_" if ri.type_name_conflict else fn_prefix
    param = f"struct {tag} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{fn_prefix}_free", [param], suffix=suffix)
2040
2041
def _print_type(ri, direction, struct):
    """Emit the C struct definition for *struct*.

    The struct tag is <family c_name>_<type_name><direction suffix>, with an
    extra '_' for direction-less types when ri.type_name_conflict is set and
    '_dump' appended for dump types that are not one-sided.

    Layout: the fixed header (if any) as _hdr, an anonymous struct named
    _present collecting the members' presence fields (if any), one __u32 per
    inherited field, then the members themselves.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        suffix += '_'
    if ri.op_mode == 'dump' and not ri.type_oneside:
        suffix += '_dump'

    ri.cw.block_start(line=f"struct {ri.family.c_name}{suffix}")

    if ri.fixed_hdr:
        ri.cw.p(f"{ri.fixed_hdr} _hdr;")
        ri.cw.nl()

    # Gather the presence lines first so the _present block is only opened
    # when there is something to put in it
    presence_lines = []
    for _, member in struct.member_list():
        for kind in ('len', 'bit'):
            text = member.presence_member(ri.ku_space, kind)
            if text:
                presence_lines.append(text)

    if presence_lines:
        ri.cw.block_start(line="struct")
        for text in presence_lines:
            ri.cw.p(text)
        ri.cw.block_end(line='_present;')
        ri.cw.nl()

    for field in struct.inherited:
        ri.cw.p(f"__u32 {field};")

    for _, member in struct.member_list():
        member.struct_member(ri)

    ri.cw.block_end(line=';')
    ri.cw.nl()
2077
2078
def print_type(ri, direction):
    """Emit the C struct for the *direction* ('request' or 'reply') of the op."""
    op_struct = ri.struct[direction]
    _print_type(ri, direction, op_struct)
2081
2082
def print_type_full(ri, struct):
    """Emit *struct* as a standalone, direction-less C struct."""
    _print_type(ri, direction="", struct=struct)
2085
2086
def print_type_helpers(ri, direction, deref=False):
    """Emit the free prototype for *direction*, plus setters when applicable.

    Setters (one per member) are only generated for user-space request types.
    """
    print_free_prototype(ri, direction)
    ri.cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
2095
2096
def print_req_type_helpers(ri):
    """Emit the alloc wrapper and helpers for a non-empty request type."""
    if not ri.type_empty("request"):
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
2102
2103
def print_rsp_type_helpers(ri):
    """Emit the reply type helpers if the current op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2108
2109
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of <op>_rsp_parse() or <op>_req_parse().

    The suffix is '_rsp' for the 'reply' direction and '_req' otherwise.
    With *terminate* the prototype ends in ';'.
    """
    kind = "_rsp" if direction == "reply" else "_req"
    base = f"{ri.op.render_name}{kind}"
    params = ['const struct nlattr **tb', f"struct {base} *req"]
    ri.cw.write_func_prot('void', f"{base}_parse", params,
                          suffix=';' if terminate else '')
2118
2119
def print_req_type(ri):
    """Emit the request struct unless the request type is empty."""
    if not ri.type_empty("request"):
        print_type(ri, "request")
2124
2125
def print_req_free(ri):
    """Emit the request free function if the current op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2130
2131
def print_rsp_type(ri):
    """Emit the reply struct when the current op mode produces a reply.

    That covers events and 'do'/'dump' modes that declare a 'reply'.
    """
    if ri.op_mode == 'event':
        print_type(ri, 'reply')
    elif ri.op_mode in ('do', 'dump') and 'reply' in ri.op[ri.op_mode]:
        print_type(ri, 'reply')
2140
2141
def print_wrapped_type(ri):
    """Emit the wrapper struct around the reply payload, plus its free prototype.

    Dump wrappers get a ``next`` pointer of the wrapper type. Notification
    and event wrappers get family/cmd fields, a ynl_ntf_base_type ``next``
    pointer and a ``free`` callback. In all cases the payload is embedded
    as ``obj`` with ``__attribute__((aligned(8)))``.
    """
    wrapper = type_name(ri, 'reply')
    ri.cw.block_start(line=f"{wrapper}")

    if ri.op_mode == 'dump':
        ri.cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        ntf_fields = ('__u16 family;',
                      '__u8 cmd;',
                      'struct ynl_ntf_base_type *next;',
                      f"void (*free)({wrapper} *ntf);")
        for field in ntf_fields:
            ri.cw.p(field)

    payload = type_name(ri, 'reply', deref=True)
    ri.cw.p(f"{payload} obj __attribute__((aligned(8)));")
    ri.cw.block_end(line=';')
    ri.cw.nl()
    print_free_prototype(ri, 'reply')
    ri.cw.nl()
2156
2157
2158def _free_type_members_iter(ri, struct):
2159    for _, attr in struct.member_list():
2160        if attr.free_needs_iter():
2161            ri.cw.p('unsigned int i;')
2162            ri.cw.nl()
2163            break
2164
2165
2166def _free_type_members(ri, var, struct, ref=''):
2167    for _, attr in struct.member_list():
2168        attr.free(ri, var, ref)
2169
2170
def _free_type(ri, direction, struct):
    """Emit the free function for *struct*.

    The members are always freed. The object itself is freed only for
    directional types; a direction-less type (direction '') only has its
    members released.
    """
    param = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, param, struct)
    if direction:
        ri.cw.p(f'free({param});')
    ri.cw.block_end()
    ri.cw.nl()
2182
2183
def free_rsp_nested_prototype(ri):
    """Emit the prototype of the free function for a direction-less struct."""
    print_free_prototype(ri, "")
2186
2187
def free_rsp_nested(ri, struct):
    """Emit the free function for nested (direction-less) struct *struct*."""
    _free_type(ri, direction="", struct=struct)
2190
2191
def print_rsp_free(ri):
    """Emit the reply free function if the current op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2196
2197
def print_dump_type_free(ri):
    """Emit the free function for a dump reply list.

    The generated C walks the list starting at rsp until YNL_LIST_END,
    freeing each element's members (reached through 'obj.') and then the
    element itself.
    """
    elem_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    ri.cw.p(f"{elem_type} *next = rsp;")
    ri.cw.nl()
    ri.cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, ri.struct['reply'])
    for stmt in ('rsp = next;', 'next = rsp->next;'):
        ri.cw.p(stmt)
    ri.cw.nl()

    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.block_end()
    ri.cw.nl()
2216
2217
def print_ntf_type_free(ri):
    """Emit the free function for a notification: its members, then rsp itself."""
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.nl()
2226
2227
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit the declaration, or the definition opener, of an nla_policy array.

    With *terminate* an extern declaration is printed. It is skipped
    entirely for op policies (ri given) that policy_should_be_static()
    marks static. Without *terminate* the definition opener ending in
    " = {" is printed, prefixed with 'static ' for such op policies.

    The array is named after the op (plus '_<op_mode>' for dual-policy ops)
    when *ri* is given, otherwise after the struct.
    """
    is_static = bool(ri) and policy_should_be_static(struct.family)

    if terminate:
        if is_static:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if is_static else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2250
2251
def print_req_policy(cw, struct, ri=None):
    """Emit a complete nla_policy array definition for *struct*.

    For op policies the array is emitted inside cw.ifdef_block() with the
    op's 'config-cond' (if any); the ifdef state is reset afterwards.
    """
    if ri and ri.op:
        cw.ifdef_block(ri.op.get('config-cond', None))

    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    members = struct.member_list()
    for _, member in members:
        member.attr_policy(cw)
    cw.p("};")

    cw.ifdef_block(None)
    cw.nl()
2261
2262
def kernel_can_gen_family_struct(family):
    """Return True if the genl_family struct is generated for *family*.

    Only 'genetlink' families qualify.
    """
    proto = family.proto
    return proto == 'genetlink'
2265
2266
def policy_should_be_static(family):
    """Return True if per-op policies of *family* should be emitted static.

    This is the case for 'split' kernel policies and for families whose
    genl_family struct is generated (kernel_can_gen_family_struct()).
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2269
2270
def print_kernel_policy_ranges(family, cw):
    """Emit a netlink_range_validation struct per 'full-range' checked attr.

    Only attributes used in requests, in attr-sets that are not subsets,
    are considered. Unsigned types ('u'-prefixed) use the unsigned variant
    with ULL limits; others use the _signed variant with LL limits. A single
    comment line precedes the first emitted struct.
    """
    header_printed = False
    for attr_set in family.attr_sets.values():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if not header_printed:
                cw.p('/* Integer value ranges */')
                header_printed = True

            unsigned = attr.type[0] == 'u'
            sign = '' if unsigned else '_signed'
            suffix = 'ULL' if unsigned else 'LL'
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            members = [(bound, attr.get_limit_str(bound, suffix=suffix))
                       for bound in ('min', 'max')
                       if bound in attr.checks]
            cw.write_struct_init(members)
            cw.block_end(line=';')
            cw.nl()
2298
2299
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the ops table declaration/opener and the handler prototypes.

    With terminate=False this prints the comment and the opening line of
    the ops array definition ("... =" followed by a block) and returns; the
    caller fills in the entries.

    With terminate=True it prints an extern declaration of the array, but
    only when the family struct is not generated here (the table is
    exported). It then prints prototypes for the pre/post do/dump hooks and
    for the doit/dumpit handlers of every non-async op.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.ident_name} */")

        struct_type = {
            'global': 'genl_small_ops',
            'per-op': 'genl_ops',
            'split': 'genl_split_ops',
        }[family.kernel_policy]

        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            # One entry per do and per dump mode
            cnt = sum(('do' in op) + ('dump' in op) for op in family.ops.values())
        else:
            cnt = len(family.ops)

        qual = 'const' if exported else 'static const'
        decl = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if not terminate:
            cw.block_start(line=decl + ' =')
            return
        cw.p(f"extern {decl};")

    cw.nl()
    doit_hook_args = ['const struct genl_split_ops *ops',
                      'struct sk_buff *skb', 'struct genl_info *info']
    dumpit_hook_args = ['struct netlink_callback *cb']
    hook_kinds = (('pre', 'do', 'int', doit_hook_args),
                  ('post', 'do', 'void', doit_hook_args),
                  ('pre', 'dump', 'int', dumpit_hook_args),
                  ('post', 'dump', 'int', dumpit_hook_args))
    for when, mode, ret_type, args in hook_kinds:
        for hook in family.hooks[when][mode]['list']:
            cw.write_func_prot(ret_type, c_lower(hook), list(args), suffix=';')

    cw.nl()

    handler_args = (('do', ['struct sk_buff *skb', 'struct genl_info *info']),
                    ('dump', ['struct sk_buff *skb', 'struct netlink_callback *cb']))
    for op_name, op in family.ops.items():
        if op.is_async:
            continue
        for mode, args in handler_args:
            if mode in op:
                handler = c_lower(f"{family.ident_name}-nl-{op_name}-{mode}it")
                cw.write_func_prot('int', handler, list(args), suffix=';')
    cw.nl()
2365
2366
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side ops table declaration and handler prototypes."""
    print_kernel_op_table_fwd(family, cw, terminate=True)
2369
2370
def print_kernel_op_table(family, cw):
    """Emit the kernel ops table definition for *family*.

    The opening line comes from print_kernel_op_table_fwd(terminate=False).
    For 'global' and 'per-op' kernel policies there is one entry per
    non-async op listing its doit/dumpit handlers; 'per-op' entries also
    reference the op's request policy and max attr. For 'split' policy
    there is one entry per do/dump mode with optional pre/post callbacks,
    the mode's request policy (if any) and a GENL_CMD_CAP_<MODE> flag.
    Each entry is emitted inside cw.ifdef_block() with the op's
    'config-cond' (if any).
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): assumes every op has do.request.attributes;
                # an op without a 'do' request would raise KeyError here -
                # confirm specs using per-op policy guarantee this.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Struct member names of the pre/post callbacks, per mode
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Drop flags that do not apply to this mode: dump-only
                    # flags for 'do', 'strict' for 'dump'
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have a separate policy per mode
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2449
2450
def print_kernel_mcgrp_hdr(family, cw):
    """Emit an anonymous enum with one <FAMILY>_NLGRP_<NAME> id per mcast group.

    Nothing is emitted when the family has no multicast groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{grp['name']}") + ',')
    cw.block_end(';')
    cw.nl()
2461
2462
def print_kernel_mcgrp_src(family, cw):
    """Emit the genl_multicast_group array, indexed by the group enum ids.

    Nothing is emitted when the family has no multicast groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        grp_id = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{grp_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2474
2475
def print_kernel_family_struct_hdr(family, cw):
    """Emit the extern genl_family declaration and sock-priv prototypes.

    Nothing is emitted unless kernel_can_gen_family_struct() allows it. The
    sock-priv init/destroy prototypes are only printed when the family's
    kernel_family section declares 'sock-priv'.
    """
    if not kernel_can_gen_family_struct(family):
        return

    prefix = family.c_name
    cw.p(f"extern struct genl_family {prefix}_nl_family;")
    cw.nl()

    if 'sock-priv' in family.kernel_family:
        priv_type = family.kernel_family["sock-priv"]
        for hook in ('init', 'destroy'):
            cw.p(f'void {prefix}_nl_sock_priv_{hook}({priv_type} *priv);')
        cw.nl()
2486
2487
def print_kernel_family_struct_src(family, cw):
    """Emit the kernel genl_family definition for *family*.

    Nothing is emitted unless kernel_can_gen_family_struct() allows it.
    When the family's kernel_family section declares 'sock-priv', static
    trampolines taking a void pointer are emitted first. They call the
    user-provided <c_name>_nl_sock_priv_init/_destroy functions and are
    wired into the struct.

    The struct is named <c_name>_nl_family. That matches the extern
    declaration printed by print_kernel_family_struct_hdr(), and every
    other identifier below is also derived from family.c_name.
    """
    if not kernel_can_gen_family_struct(family):
        return

    name = family.c_name

    if 'sock-priv' in family.kernel_family:
        # Generate "trampolines" to make CFI happy
        cw.write_func("static void", f"__{name}_nl_sock_priv_init",
                      [f"{name}_nl_sock_priv_init(priv);"],
                      ["void *priv"])
        cw.nl()
        cw.write_func("static void", f"__{name}_nl_sock_priv_destroy",
                      [f"{name}_nl_sock_priv_destroy(priv);"],
                      ["void *priv"])
        cw.nl()

    # Use c_name (previously ident_name) so the definition always matches
    # the extern declaration emitted for the header
    cw.block_start(f"struct genl_family {name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({name}_nl_mcgrps),')
    if 'sock-priv' in family.kernel_family:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        cw.p(f'.sock_priv_init\t= __{name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = __{name}_nl_sock_priv_destroy,')
    cw.block_end(';')
2523
2524
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a uapi enum block, giving it a name when *obj* provides one.

    If obj has the *enum_name* key, its value names the enum; an empty or
    null value keeps the enum anonymous. Otherwise, if *ckey* is given and
    present, the enum is named <family c_name>_<obj[ckey]>. If neither
    applies, the enum is anonymous.
    """
    if enum_name in obj:
        explicit = obj[enum_name]
        start_line = f'enum {c_lower(explicit)}' if explicit else 'enum'
    elif ckey and ckey in obj:
        start_line = f'enum {family.c_name}_{c_lower(obj[ckey])}'
    else:
        start_line = 'enum'
    cw.block_start(line=start_line)
2533
2534
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """Emit the single uapi command enum plus its count and max markers.

    With *separate_ntf*, notify/event messages are left out of the enum.
    An explicit '= value' is printed only when an op's value breaks the
    implicit sequence. The max marker is an enumerator, or a #define after
    the enum when *max_by_define* is set.
    """
    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    expected = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        if op.value == expected:
            cw.p(f"{op.enum_name},")
        else:
            cw.p(f"{op.enum_name} = {op.value},")
        expected = op.value + 1
    cw.nl()

    if max_by_define:
        cw.p(cnt_name)
        cw.block_end(line=';')
        cw.p(f"#define {max_name} {max_value}")
    else:
        cw.p(cnt_name + ',')
        cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
    cw.nl()
2560
2561
def render_uapi_directional(family, cw, max_by_define):
    """Emit separate uapi enums for user-to-kernel and kernel-to-user messages.

    The USER enum lists 'do' messages that are not events. The KERNEL enum
    lists do-replies (with a _REPLY suffix), notifications and events. Each
    enum starts with <NAME>_MSG_<DIR>_NONE = 0 and ends with
    __<prefix><DIR>_CNT plus a <prefix><DIR>_MAX marker. The marker is a
    #define after the enum when *max_by_define* is set.
    """
    def emit_enum(tag, entries):
        # tag is 'USER' or 'KERNEL'; entries yields (enum name, value) pairs
        max_name = f"{family.op_prefix}{tag}_MAX"
        cnt_name = f"__{family.op_prefix}{tag}_CNT"
        max_value = f"({cnt_name} - 1)"

        cw.block_start(line='enum')
        cw.p(c_upper(f'{family.name}_MSG_{tag}_NONE = 0,'))
        val = 0
        for entry_name, value in entries:
            if value and value != val:
                cw.p(f"{entry_name} = {value},")
                val = value
            else:
                cw.p(f"{entry_name},")
            val += 1
        cw.nl()

        if max_by_define:
            cw.p(cnt_name)
            cw.block_end(line=';')
            cw.p(f"#define {max_name} {max_value}")
        else:
            cw.p(cnt_name + ',')
            cw.p(f"{max_name} = {max_value}")
            cw.block_end(line=';')
        cw.nl()

    def user_msgs():
        for op in family.msgs.values():
            if 'do' in op and 'event' not in op:
                yield op.enum_name, op.value

    def kernel_msgs():
        for op in family.msgs.values():
            is_ntf = 'notify' in op or 'event' in op
            if is_ntf or ('do' in op and 'reply' in op['do']):
                entry_name = op.enum_name if is_ntf else f'{op.enum_name}_REPLY'
                yield entry_name, op.value

    emit_enum('USER', user_msgs())
    emit_enum('KERNEL', kernel_msgs())
2614
2615
2616def render_uapi(family, cw):
2617    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
2618    hdr_prot = hdr_prot.replace('/', '_')
2619    cw.p('#ifndef ' + hdr_prot)
2620    cw.p('#define ' + hdr_prot)
2621    cw.nl()
2622
2623    defines = [(family.fam_key, family["name"]),
2624               (family.ver_key, family.get('version', 1))]
2625    cw.writes_defines(defines)
2626    cw.nl()
2627
2628    defines = []
2629    for const in family['definitions']:
2630        if const.get('header'):
2631            continue
2632
2633        if const['type'] != 'const':
2634            cw.writes_defines(defines)
2635            defines = []
2636            cw.nl()
2637
2638        # Write kdoc for enum and flags (one day maybe also structs)
2639        if const['type'] == 'enum' or const['type'] == 'flags':
2640            enum = family.consts[const['name']]
2641
2642            if enum.header:
2643                continue
2644
2645            if enum.has_doc():
2646                if enum.has_entry_doc():
2647                    cw.p('/**')
2648                    doc = ''
2649                    if 'doc' in enum:
2650                        doc = ' - ' + enum['doc']
2651                    cw.write_doc_line(enum.enum_name + doc)
2652                else:
2653                    cw.p('/*')
2654                    cw.write_doc_line(enum['doc'], indent=False)
2655                for entry in enum.entries.values():
2656                    if entry.has_doc():
2657                        doc = '@' + entry.c_name + ': ' + entry['doc']
2658                        cw.write_doc_line(doc)
2659                cw.p(' */')
2660
2661            uapi_enum_start(family, cw, const, 'name')
2662            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
2663            for entry in enum.entries.values():
2664                suffix = ','
2665                if entry.value_change:
2666                    suffix = f" = {entry.user_value()}" + suffix
2667                cw.p(entry.c_name + suffix)
2668
2669            if const.get('render-max', False):
2670                cw.nl()
2671                cw.p('/* private: */')
2672                if const['type'] == 'flags':
2673                    max_name = c_upper(name_pfx + 'mask')
2674                    max_val = f' = {enum.get_mask()},'
2675                    cw.p(max_name + max_val)
2676                else:
2677                    cnt_name = enum.enum_cnt_name
2678                    max_name = c_upper(name_pfx + 'max')
2679                    if not cnt_name:
2680                        cnt_name = '__' + name_pfx + 'max'
2681                    cw.p(c_upper(cnt_name) + ',')
2682                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
2683            cw.block_end(line=';')
2684            cw.nl()
2685        elif const['type'] == 'const':
2686            defines.append([c_upper(family.get('c-define-name',
2687                                               f"{family.ident_name}-{const['name']}")),
2688                            const['value']])
2689
2690    if defines:
2691        cw.writes_defines(defines)
2692        cw.nl()
2693
2694    max_by_define = family.get('max-by-define', False)
2695
2696    for _, attr_set in family.attr_sets.items():
2697        if attr_set.subset_of:
2698            continue
2699
2700        max_value = f"({attr_set.cnt_name} - 1)"
2701
2702        val = 0
2703        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
2704        for _, attr in attr_set.items():
2705            suffix = ','
2706            if attr.value != val:
2707                suffix = f" = {attr.value},"
2708                val = attr.value
2709            val += 1
2710            cw.p(attr.enum_name + suffix)
2711        if attr_set.items():
2712            cw.nl()
2713        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
2714        if not max_by_define:
2715            cw.p(f"{attr_set.max_name} = {max_value}")
2716        cw.block_end(line=';')
2717        if max_by_define:
2718            cw.p(f"#define {attr_set.max_name} {max_value}")
2719        cw.nl()
2720
2721    # Commands
2722    separate_ntf = 'async-prefix' in family['operations']
2723
2724    if family.msg_id_model == 'unified':
2725        render_uapi_unified(family, cw, max_by_define, separate_ntf)
2726    elif family.msg_id_model == 'directional':
2727        render_uapi_directional(family, cw, max_by_define)
2728    else:
2729        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')
2730
2731    if separate_ntf:
2732        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
2733        for op in family.msgs.values():
2734            if separate_ntf and not ('notify' in op or 'event' in op):
2735                continue
2736
2737            suffix = ','
2738            if 'value' in op:
2739                suffix = f" = {op['value']},"
2740            cw.p(op.enum_name + suffix)
2741        cw.block_end(line=';')
2742        cw.nl()
2743
2744    # Multicast
2745    defines = []
2746    for grp in family.mcgrps['list']:
2747        name = grp['name']
2748        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
2749                        f'{name}'])
2750    cw.nl()
2751    if defines:
2752        cw.writes_defines(defines)
2753        cw.nl()
2754
2755    cw.p(f'#endif /* {hdr_prot} */')
2756
2757
2758def _render_user_ntf_entry(ri, op):
2759    ri.cw.block_start(line=f"[{op.enum_name}] = ")
2760    ri.cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
2761    ri.cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
2762    ri.cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
2763    ri.cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
2764    ri.cw.block_end(line=',')
2765
2766
2767def render_user_family(family, cw, prototype):
2768    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
2769    if prototype:
2770        cw.p(f'extern {symbol};')
2771        return
2772
2773    if family.ntfs:
2774        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
2775        for ntf_op_name, ntf_op in family.ntfs.items():
2776            if 'notify' in ntf_op:
2777                op = family.ops[ntf_op['notify']]
2778                ri = RenderInfo(cw, family, "user", op, "notify")
2779            elif 'event' in ntf_op:
2780                ri = RenderInfo(cw, family, "user", ntf_op, "event")
2781            else:
2782                raise Exception('Invalid notification ' + ntf_op_name)
2783            _render_user_ntf_entry(ri, ntf_op)
2784        for op_name, op in family.ops.items():
2785            if 'event' not in op:
2786                continue
2787            ri = RenderInfo(cw, family, "user", op, "event")
2788            _render_user_ntf_entry(ri, op)
2789        cw.block_end(line=";")
2790        cw.nl()
2791
2792    cw.block_start(f'{symbol} = ')
2793    cw.p(f'.name\t\t= "{family.c_name}",')
2794    if family.is_classic():
2795        cw.p(f'.is_classic\t= true,')
2796        cw.p(f'.classic_id\t= {family.get("protonum")},')
2797    if family.is_classic():
2798        cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
2799    elif family.fixed_header:
2800        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
2801    else:
2802        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
2803    if family.ntfs:
2804        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
2805        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
2806    cw.block_end(line=';')
2807
2808
2809def family_contains_bitfield32(family):
2810    for _, attr_set in family.attr_sets.items():
2811        if attr_set.subset_of:
2812            continue
2813        for _, attr in attr_set.items():
2814            if attr.type == "bitfield32":
2815                return True
2816    return False
2817
2818
2819def find_kernel_root(full_path):
2820    sub_path = ''
2821    while True:
2822        sub_path = os.path.join(os.path.basename(full_path), sub_path)
2823        full_path = os.path.dirname(full_path)
2824        maintainers = os.path.join(full_path, "MAINTAINERS")
2825        if os.path.exists(maintainers):
2826            return full_path, sub_path[:-1]
2827
2828
2829def main():
2830    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
2831    parser.add_argument('--mode', dest='mode', type=str, required=True,
2832                        choices=('user', 'kernel', 'uapi'))
2833    parser.add_argument('--spec', dest='spec', type=str, required=True)
2834    parser.add_argument('--header', dest='header', action='store_true', default=None)
2835    parser.add_argument('--source', dest='header', action='store_false')
2836    parser.add_argument('--user-header', nargs='+', default=[])
2837    parser.add_argument('--cmp-out', action='store_true', default=None,
2838                        help='Do not overwrite the output file if the new output is identical to the old')
2839    parser.add_argument('--exclude-op', action='append', default=[])
2840    parser.add_argument('-o', dest='out_file', type=str, default=None)
2841    args = parser.parse_args()
2842
2843    if args.header is None:
2844        parser.error("--header or --source is required")
2845
2846    exclude_ops = [re.compile(expr) for expr in args.exclude_op]
2847
2848    try:
2849        parsed = Family(args.spec, exclude_ops)
2850        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
2851            print('Spec license:', parsed.license)
2852            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
2853            os.sys.exit(1)
2854    except yaml.YAMLError as exc:
2855        print(exc)
2856        os.sys.exit(1)
2857        return
2858
2859    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))
2860
2861    _, spec_kernel = find_kernel_root(args.spec)
2862    if args.mode == 'uapi' or args.header:
2863        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
2864    else:
2865        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
2866    cw.p("/* Do not edit directly, auto-generated from: */")
2867    cw.p(f"/*\t{spec_kernel} */")
2868    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
2869    if args.exclude_op or args.user_header:
2870        line = ''
2871        line += ' --user-header '.join([''] + args.user_header)
2872        line += ' --exclude-op '.join([''] + args.exclude_op)
2873        cw.p(f'/* YNL-ARG{line} */')
2874    cw.nl()
2875
2876    if args.mode == 'uapi':
2877        render_uapi(parsed, cw)
2878        return
2879
2880    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
2881    if args.header:
2882        cw.p('#ifndef ' + hdr_prot)
2883        cw.p('#define ' + hdr_prot)
2884        cw.nl()
2885
2886    if args.out_file:
2887        hdr_file = os.path.basename(args.out_file[:-2]) + ".h"
2888    else:
2889        hdr_file = "generated_header_file.h"
2890
2891    if args.mode == 'kernel':
2892        cw.p('#include <net/netlink.h>')
2893        cw.p('#include <net/genetlink.h>')
2894        cw.nl()
2895        if not args.header:
2896            if args.out_file:
2897                cw.p(f'#include "{hdr_file}"')
2898            cw.nl()
2899        headers = ['uapi/' + parsed.uapi_header]
2900        headers += parsed.kernel_family.get('headers', [])
2901    else:
2902        cw.p('#include <stdlib.h>')
2903        cw.p('#include <string.h>')
2904        if args.header:
2905            cw.p('#include <linux/types.h>')
2906            if family_contains_bitfield32(parsed):
2907                cw.p('#include <linux/netlink.h>')
2908        else:
2909            cw.p(f'#include "{hdr_file}"')
2910            cw.p('#include "ynl.h"')
2911        headers = []
2912    for definition in parsed['definitions'] + parsed['attribute-sets']:
2913        if 'header' in definition:
2914            headers.append(definition['header'])
2915    if args.mode == 'user':
2916        headers.append(parsed.uapi_header)
2917    seen_header = []
2918    for one in headers:
2919        if one not in seen_header:
2920            cw.p(f"#include <{one}>")
2921            seen_header.append(one)
2922    cw.nl()
2923
2924    if args.mode == "user":
2925        if not args.header:
2926            cw.p("#include <linux/genetlink.h>")
2927            cw.nl()
2928            for one in args.user_header:
2929                cw.p(f'#include "{one}"')
2930        else:
2931            cw.p('struct ynl_sock;')
2932            cw.nl()
2933            render_user_family(parsed, cw, True)
2934        cw.nl()
2935
2936    if args.mode == "kernel":
2937        if args.header:
2938            for _, struct in sorted(parsed.pure_nested_structs.items()):
2939                if struct.request:
2940                    cw.p('/* Common nested types */')
2941                    break
2942            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2943                if struct.request:
2944                    print_req_policy_fwd(cw, struct)
2945            cw.nl()
2946
2947            if parsed.kernel_policy == 'global':
2948                cw.p(f"/* Global operation policy for {parsed.name} */")
2949
2950                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
2951                print_req_policy_fwd(cw, struct)
2952                cw.nl()
2953
2954            if parsed.kernel_policy in {'per-op', 'split'}:
2955                for op_name, op in parsed.ops.items():
2956                    if 'do' in op and 'event' not in op:
2957                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
2958                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
2959                        cw.nl()
2960
2961            print_kernel_op_table_hdr(parsed, cw)
2962            print_kernel_mcgrp_hdr(parsed, cw)
2963            print_kernel_family_struct_hdr(parsed, cw)
2964        else:
2965            print_kernel_policy_ranges(parsed, cw)
2966
2967            for _, struct in sorted(parsed.pure_nested_structs.items()):
2968                if struct.request:
2969                    cw.p('/* Common nested types */')
2970                    break
2971            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2972                if struct.request:
2973                    print_req_policy(cw, struct)
2974            cw.nl()
2975
2976            if parsed.kernel_policy == 'global':
2977                cw.p(f"/* Global operation policy for {parsed.name} */")
2978
2979                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
2980                print_req_policy(cw, struct)
2981                cw.nl()
2982
2983            for op_name, op in parsed.ops.items():
2984                if parsed.kernel_policy in {'per-op', 'split'}:
2985                    for op_mode in ['do', 'dump']:
2986                        if op_mode in op and 'request' in op[op_mode]:
2987                            cw.p(f"/* {op.enum_name} - {op_mode} */")
2988                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
2989                            print_req_policy(cw, ri.struct['request'], ri=ri)
2990                            cw.nl()
2991
2992            print_kernel_op_table(parsed, cw)
2993            print_kernel_mcgrp_src(parsed, cw)
2994            print_kernel_family_struct_src(parsed, cw)
2995
2996    if args.mode == "user":
2997        if args.header:
2998            cw.p('/* Enums */')
2999            put_op_name_fwd(parsed, cw)
3000
3001            for name, const in parsed.consts.items():
3002                if isinstance(const, EnumSet):
3003                    put_enum_to_str_fwd(parsed, cw, const)
3004            cw.nl()
3005
3006            cw.p('/* Common nested types */')
3007            for attr_set, struct in parsed.pure_nested_structs.items():
3008                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3009                print_type_full(ri, struct)
3010                if struct.request and struct.in_multi_val:
3011                    free_rsp_nested_prototype(ri)
3012                    cw.nl()
3013
3014            for op_name, op in parsed.ops.items():
3015                cw.p(f"/* ============== {op.enum_name} ============== */")
3016
3017                if 'do' in op and 'event' not in op:
3018                    cw.p(f"/* {op.enum_name} - do */")
3019                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3020                    print_req_type(ri)
3021                    print_req_type_helpers(ri)
3022                    cw.nl()
3023                    print_rsp_type(ri)
3024                    print_rsp_type_helpers(ri)
3025                    cw.nl()
3026                    print_req_prototype(ri)
3027                    cw.nl()
3028
3029                if 'dump' in op:
3030                    cw.p(f"/* {op.enum_name} - dump */")
3031                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
3032                    print_req_type(ri)
3033                    print_req_type_helpers(ri)
3034                    if not ri.type_consistent or ri.type_oneside:
3035                        print_rsp_type(ri)
3036                    print_wrapped_type(ri)
3037                    print_dump_prototype(ri)
3038                    cw.nl()
3039
3040                if op.has_ntf:
3041                    cw.p(f"/* {op.enum_name} - notify */")
3042                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
3043                    if not ri.type_consistent:
3044                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
3045                    print_wrapped_type(ri)
3046
3047            for op_name, op in parsed.ntfs.items():
3048                if 'event' in op:
3049                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
3050                    cw.p(f"/* {op.enum_name} - event */")
3051                    print_rsp_type(ri)
3052                    cw.nl()
3053                    print_wrapped_type(ri)
3054            cw.nl()
3055        else:
3056            cw.p('/* Enums */')
3057            put_op_name(parsed, cw)
3058
3059            for name, const in parsed.consts.items():
3060                if isinstance(const, EnumSet):
3061                    put_enum_to_str(parsed, cw, const)
3062            cw.nl()
3063
3064            has_recursive_nests = False
3065            cw.p('/* Policies */')
3066            for struct in parsed.pure_nested_structs.values():
3067                if struct.recursive:
3068                    put_typol_fwd(cw, struct)
3069                    has_recursive_nests = True
3070            if has_recursive_nests:
3071                cw.nl()
3072            for name in parsed.pure_nested_structs:
3073                struct = Struct(parsed, name)
3074                put_typol(cw, struct)
3075            for name in parsed.root_sets:
3076                struct = Struct(parsed, name)
3077                put_typol(cw, struct)
3078
3079            cw.p('/* Common nested types */')
3080            if has_recursive_nests:
3081                for attr_set, struct in parsed.pure_nested_structs.items():
3082                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3083                    free_rsp_nested_prototype(ri)
3084                    if struct.request:
3085                        put_req_nested_prototype(ri, struct)
3086                    if struct.reply:
3087                        parse_rsp_nested_prototype(ri, struct)
3088                cw.nl()
3089            for attr_set, struct in parsed.pure_nested_structs.items():
3090                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3091
3092                free_rsp_nested(ri, struct)
3093                if struct.request:
3094                    put_req_nested(ri, struct)
3095                if struct.reply:
3096                    parse_rsp_nested(ri, struct)
3097
3098            for op_name, op in parsed.ops.items():
3099                cw.p(f"/* ============== {op.enum_name} ============== */")
3100                if 'do' in op and 'event' not in op:
3101                    cw.p(f"/* {op.enum_name} - do */")
3102                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3103                    print_req_free(ri)
3104                    print_rsp_free(ri)
3105                    parse_rsp_msg(ri)
3106                    print_req(ri)
3107                    cw.nl()
3108
3109                if 'dump' in op:
3110                    cw.p(f"/* {op.enum_name} - dump */")
3111                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
3112                    if not ri.type_consistent or ri.type_oneside:
3113                        parse_rsp_msg(ri, deref=True)
3114                    print_req_free(ri)
3115                    print_dump_type_free(ri)
3116                    print_dump(ri)
3117                    cw.nl()
3118
3119                if op.has_ntf:
3120                    cw.p(f"/* {op.enum_name} - notify */")
3121                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
3122                    if not ri.type_consistent:
3123                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
3124                    print_ntf_type_free(ri)
3125
3126            for op_name, op in parsed.ntfs.items():
3127                if 'event' in op:
3128                    cw.p(f"/* {op.enum_name} - event */")
3129
3130                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3131                    parse_rsp_msg(ri)
3132
3133                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
3134                    print_ntf_type_free(ri)
3135            cw.nl()
3136            render_user_family(parsed, cw, False)
3137
3138    if args.header:
3139        cw.p(f'#endif /* {hdr_prot} */')
3140
3141
3142if __name__ == "__main__":
3143    main()
3144