xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision b5c6891b2c5b54bf58069966296917da46cda6f2)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17
18
def c_upper(name):
    """Turn a spec name such as ``foo-bar`` into an upper-case C name (``FOO_BAR``)."""
    underscored = name.replace('-', '_')
    return underscored.upper()
21
22
def c_lower(name):
    """Turn a spec name such as ``Foo-Bar`` into a lower-case C name (``foo_bar``)."""
    underscored = name.replace('-', '_')
    return underscored.lower()
25
26
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value

    The limit must have the form ``<u|s><width>-<min|max>``, e.g. ``u8-max``
    (255), ``s16-min`` (-32768) or ``u64-min`` (0).

    Raises:
        ValueError: if @name does not follow that form (previously strings
            such as ``s32-mid`` were silently treated as a maximum).
    """
    match = re.fullmatch(r'([us])([0-9]+)-(min|max)', name)
    if match is None:
        raise ValueError(f'Invalid limit "{name}", expected e.g. u32-max or s64-min')
    sign, width, bound = match.group(1), int(match.group(2)), match.group(3)

    if sign == 'u' and bound == 'min':
        return 0
    # Signed types spend one bit on the sign
    if sign == 's':
        width -= 1
    value = (1 << width) - 1
    if sign == 's' and bound == 'min':
        value = -value - 1
    return value
40
41
class BaseNlLib:
    """Naming helpers for the YNL library state referenced by generated code."""

    def get_family_id(self):
        # C expression for the family ID held in the 'ys' state
        # (presumably a struct ynl_sock pointer — confirm against users)
        family_id_expr = 'ys->family_id'
        return family_id_expr
45
46
class Type(SpecAttr):
    """
    Base class for the C code generation of a single netlink attribute.

    Subclasses specialize the hooks used to render the attribute: kernel
    policy entry, YNL type-info ("typol") entry, struct member, put (request
    building), get (reply parsing) and setter helpers.  Hooks a type does
    not implement raise an Exception naming the type.  Methods taking @ri
    emit C code through ``ri.cw``.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        # NOTE(review): when the spec has 'checks' this is the spec's own dict,
        # so checks derived later (e.g. by TypeScalar) are written into it.
        self.checks = attr.get('checks', {})

        # Flipped by set_request() / set_reply()
        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            # The set named after the family renders with the bare family name
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # Same rule as Struct.struct_name: a nest sharing its name with an
            # entry in family.consts gets a trailing '_' (presumably to avoid
            # a C name clash)
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Member name in C; escape keywords (_C_KW) and leading digits
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attr as used in requests (also on the attr a subset refers to)."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr as used in replies (also on the attr a subset refers to)."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """
        Return the numeric value of check @limit (e.g. 'min', 'max').

        The check may be an int, the name of a family constant, or a string
        such as 'u32-max' (see limit_to_number()).  When the check is absent
        @default is used, resolved the same way; None is returned as is.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """
        Return check @limit as a C expression string.

        Ints are rendered literally with @suffix appended; family constants
        by their macro name (unprefixed when the const has a 'header');
        anything else is upper-cased as is.  Returns '' when absent.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute the attr's C enum name; subset attrs may not override checks."""
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        return None

    def is_scalar(self):
        # NOTE(review): 's8'/'s16' (and auto-sized 'uint'/'sint') are not
        # listed here; confirm whether that is intentional.
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        # Recursive nests are only special outside of an operation's own struct
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """How presence is tracked: 'bit', 'len', 'count' or '' (not tracked)."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """
        Return the _present struct member declaration for this attr.

        Returns None unless presence_type() equals @type_filter (and for
        presence types other than 'bit'/'len').  @space 'user' selects the
        __u32 spelling.
        """
        if self.presence_type() != type_filter:
            return

        if self.presence_type() == 'bit':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() == 'len':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}_len;"

    def _complex_member_type(self, ri):
        """C type of the member when it is not a plain scalar, else None."""
        return None

    def free_needs_iter(self):
        """Whether the free code needs a loop index 'i' declared."""
        return False

    def _free_lines(self, ri, var, ref):
        # Multi-value arrays and length-tracked buffers are heap allocated
        if self.is_multi_val() or self.presence_type() == 'len':
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        """Emit the code releasing this attr's member of @var."""
        lines = self._free_lines(ri, var, ref)
        for line in lines:
            ri.cw.p(line)

    def arg_member(self, ri):
        """Return the setter argument declarations for this attr."""
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for this attr."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attr's entry in the kernel nla_policy table."""
        policy = f'NLA_{c_upper(self.type)}'
        # Only 16/32-bit big-endian ints get a dedicated NLA_BE* type here
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attr's entry in the YNL user space type-info table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        # Guard the put with the presence member, if the type has one
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return (lines, init_lines, local_vars) of parse code; str/list/None each."""
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Emit this attr's branch of the reply-parsing if / else-if chain.

        @first selects 'if' over 'else if'.  Non multi-value attrs are
        validated and, for bit presence, marked present before the
        type-specific lines run.  Always returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """
        Emit a static inline setter storing this attr into the request struct.

        @ref is the path of enclosing nest members (outermost first); the
        setter marks every level present, frees old contents and stores the
        new value.  Setters that free without allocating get a '__' prefix.
        @space is not used here.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        for i, part in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{part}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == len(ref) - 1 and self.presence_type() != 'bit':
                continue
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
299
300
class TypeUnused(Type):
    """
    Attribute slot marked 'unused' in the spec.

    Its type-info entry is YNL_PT_REJECT; no policy entry, put code, parse
    branch or setter is generated for it.
    """

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def presence_type(self):
        # No presence tracking
        return ''

    def arg_member(self, ri):
        return []

    def _attr_get(self, ri, var):
        # Parse code for this attr would just fail the parse
        return ['return YNL_PARSE_CB_ERROR;'], None, None

    def attr_policy(self, cw):
        """Unused attrs get no policy entry."""

    def attr_put(self, ri, var):
        """Nothing is put for unused attrs."""

    def attr_get(self, ri, var, first):
        """No parse branch is emitted for unused attrs."""

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Unused attrs have no setter."""
325
326
class TypePad(Type):
    """
    Padding attribute.

    Its type-info entry is YNL_PT_IGNORE; no policy entry, put code, parse
    branch or setter is generated for it.
    """

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def presence_type(self):
        # No presence tracking
        return ''

    def arg_member(self, ri):
        return []

    def attr_policy(self, cw):
        """Pad attrs get no policy entry."""

    def attr_put(self, ri, var):
        """Nothing is put for pad attrs."""

    def attr_get(self, ri, var, first):
        """No parse branch is emitted for pad attrs."""

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Pad attrs have no setter."""
348
349
class TypeScalar(Type):
    """
    Integer attribute.

    Range checks may come from the spec or be derived from the enum the
    attribute refers to; they select the NLA_POLICY_* macro emitted.
    Auto-sized scalars (is_auto_scalar) use a 64-bit C type.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Default the range checks to the values of the referenced enum
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if 'min' not in self.checks:
                # A 0 minimum on an unsigned type needs no check
                if low != 0 or self.type[0] == 's':
                    self.checks['min'] = low
            if 'max' not in self.checks:
                self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside the s16 range use NLA_POLICY_FULL_RANGE (presumably
        # because the inline policy min/max fields are 16-bit)
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        self.resolve_up(super())

        # Flag-style values are validated with a mask rather than a range
        if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']:
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        # C type of the struct member / setter argument
        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _attr_policy(self, policy):
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
                mask = enum.get_mask(as_flags=True)
            else:
                flags = self.family.consts[self.checks['flags-mask']]
                # Build the mask from the actual flag bits; the previous
                # (1 << len(entries)) - 1 assumed flags start at bit 0 and
                # was wrong for flag sets with a non-zero value-start.
                mask = flags.get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed and unsigned share the YNL_PT_U<width> type-info kinds
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
436
437
class TypeFlag(Type):
    """
    Flag attribute: put with a zero-length payload.

    It keeps the default bit presence, which the generic parse and setter
    code set; the flag needs no value of its own.
    """

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        # The setter takes no value argument
        return []

    def attr_put(self, ri, var):
        put_line = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_line)

    def _attr_get(self, ri, var):
        # Nothing beyond the presence bit to extract
        return [], None, None

    def _setter_lines(self, ri, member, presence):
        return []
453
454
class TypeString(Type):
    """
    String attribute, stored as a malloc'ed NUL-terminated copy.

    Presence is tracked by length (_present.<name>_len) instead of a bit.
    """

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        parts = ['{ .type = ' + policy]
        if 'max-len' in self.checks:
            parts.append(', .len = ' + self.get_limit_str('max-len'))
        parts.append(', }')
        return ''.join(parts)

    def attr_policy(self, cw):
        # 'unterminated-ok' relaxes NLA_NUL_STRING to NLA_STRING
        policy = 'NLA_STRING' if self.checks.get('unterminated-ok', False) else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        member = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        parse_lines = [
            f"{len_mem} = len;",
            f"{member} = malloc(len + 1);",
            f"memcpy({member}, ynl_attr_get_str(attr), len);",
            f"{member}[len] = 0;",
        ]
        # Length is bounded by the attribute payload, not the terminator alone
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return parse_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"{length} = strlen({self.c_name});",
                f"{member} = malloc({length} + 1);",
                f'memcpy({member}, {self.c_name}, {length});',
                f'{member}[{length}] = 0;']
504
505
class TypeBinary(Type):
    """
    Opaque binary attribute, stored as a malloc'ed copy plus its length.

    Presence is tracked by length (_present.<name>_len).  At most one of the
    checks exact-len / min-len / max-len is supported.
    """

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        # NOTE: unlike the other types there is no trailing space here
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        if len(self.checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        if len(self.checks) == 1:
            check_name = next(iter(self.checks))
            if check_name not in {'exact-len', 'min-len', 'max-len'}:
                raise Exception('Unsupported check for binary type: ' + check_name)

        if not self.checks:
            return '{ .type = NLA_BINARY, }'
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        if 'min-len' in self.checks:
            return '{ .len = ' + self.get_limit_str('min-len') + ', }'
        # Only 'max-len' is left at this point
        return 'NLA_POLICY_MAX_LEN(' + self.get_limit_str('max-len') + ')'

    def attr_put(self, ri, var):
        data = f"{var}->{self.c_name}"
        length = f"{var}->_present.{self.c_name}_len"
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, {data}, {length})")

    def _attr_get(self, ri, var):
        member = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        parse_lines = [f"{len_mem} = len;",
                       f"{member} = malloc(len);",
                       f"memcpy({member}, ynl_attr_data(attr), len);"]
        return parse_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"{length} = len;",
                f"{member} = malloc({length});",
                f'memcpy({member}, {self.c_name}, {length});']
556
557
class TypeBitfield32(Type):
    """struct nla_bitfield32 attribute; the valid bits come from an enum."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        valid_bits = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({valid_bits})"

    def attr_put(self, ri, var):
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, sizeof(struct nla_bitfield32))")

    def _attr_get(self, ri, var):
        copy_line = f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));"
        return copy_line, None, None

    def _setter_lines(self, ri, member, presence):
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
581
582
class TypeNest(Type):
    """
    Attribute carrying a nested attribute set, rendered as a struct member.

    Recursive nests (outside their own op) are held by pointer.
    """

    def is_recursive(self):
        nested = self.family.pure_nested_structs[self.nested_attrs]
        return nested.recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        member = f"{var}->{ref}{self.c_name}"
        if not self.is_recursive_for_op(ri):
            return [f'{self.nested_render_name}_free(&{member});']
        # Pointer member: only free it when set
        return [f'if ({member})',
                f'{self.nested_render_name}_free({member});']

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return 'NLA_POLICY_NESTED(' + self.nested_render_name + '_nl_policy)'

    def attr_put(self, ri, var):
        member = f"{var}->{self.c_name}"
        if not self.is_recursive_for_op(ri):
            member = '&' + member
        self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, {self.enum_name}, {member})")

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [f"if ({self.nested_render_name}_parse(&parg, attr))",
                     "return YNL_PARSE_CB_ERROR;"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        # A nest has no setter of its own; emit one per non-recursive member
        path = (ref or []) + [self.c_name]
        members = ri.family.pure_nested_structs[self.nested_attrs].member_list()
        for _, member_attr in members:
            if not member_attr.is_recursive():
                member_attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
624
625
class TypeMultiAttr(Type):
    """
    Attribute that may repeat within one message ("multi-attr").

    Values are collected into an array with an n_<name> element count.
    @base_type is the single-instance type object used to render the
    policy and type-info entries.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)
        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'type' in self.attr and self.attr['type'] != 'nest':
            if self.attr['type'] not in scalars:
                raise Exception(f"Sub-type {self.attr['type']} not supported yet")
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        return self.nested_struct_type

    def free_needs_iter(self):
        # Nested elements are freed one by one in a loop
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def _free_lines(self, ri, var, ref):
        array = f"{var}->{ref}{self.c_name}"
        if self.attr['type'] in scalars:
            return [f"free({array});"]
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return [f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)",
                    f'{self.nested_render_name}_free(&{array}[i]);',
                    f"free({array});"]
        raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only count occurrences here
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        loop = f"for (i = 0; i < {var}->n_{self.c_name}; i++)"
        if self.attr['type'] in scalars:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        elif 'type' not in self.attr or self.attr['type'] == 'nest':
            ri.cw.p(loop)
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, {self.enum_name}, &{var}->{self.c_name}[i])")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # Multi-attrs keep a count, not a presence bit: rewrite the trailing
        # "_present.<name>" of the presence path into "n_<name>"
        suffix_len = len('_present.') + len(self.c_name)
        count = presence[:-suffix_len] + "n_" + self.c_name
        return [f"{member} = {self.c_name};",
                f"{count} = n_{self.c_name};"]
690
691
class TypeArrayNest(Type):
    """
    Nest whose sub-attributes are all array elements of one type.

    The element type is 'sub-type'; when it is absent the elements are
    treated as nests.  The array is kept with an n_<name> element count.
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
            return self.nested_struct_type
        elif self.attr['sub-type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['sub-type']
        else:
            raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")

    def _attr_typol(self):
        # A missing 'sub-type' means nest, as in _complex_member_type();
        # indexing self.attr['sub-type'] directly raised KeyError for it.
        if self.attr.get('sub-type') in scalars:
            return f'.type = YNL_PT_U{c_upper(self.sub_type[1:])}, '
        else:
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Validate and count the elements; the caller's attr_<name> variable
        # remembers the nest for a later pass
        local_vars = ['const struct nlattr *attr2;']
        get_lines = [f'attr_{self.c_name} = attr;',
                     'ynl_attr_for_each_nested(attr2, attr) {',
                     '\tif (ynl_attr_validate(yarg, attr2))',
                     '\t\treturn YNL_PARSE_CB_ERROR;',
                     f'\t{var}->n_{self.c_name}++;',
                     '}']
        return get_lines, None, local_vars
723
724
class TypeNestTypeValue(Type):
    """
    Nest where nested attribute *types* carry values ('type-value').

    For each level listed in 'type-value' the parser descends one nest and
    reads that attribute's type as a __u32, passing all of them to the
    nested struct's parse function.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        innermost = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(levels)};')
            local_vars.append(f'__u32 {", ".join(levels)};')
            for level in levels:
                get_lines.append(f'attr_{level} = ynl_attr_data({innermost});')
                get_lines.append(f'{level} = ynl_attr_type(attr_{level});')
                innermost = 'attr_' + level
            extra_args = f", {', '.join(levels)}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {innermost}{extra_args});")
        return get_lines, init_lines, local_vars
753
754
class Struct:
    """
    C struct rendering of (part of) an attribute set.

    @type_list selects and orders the members; when it is None every attr
    of the set is used and the struct is marked as nested.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            self.render_name = c_lower(family.ident_name)
        else:
            self.render_name = c_lower(family.ident_name + '-' + space_name)
        # Nested structs named like an entry of family.consts get a '_' suffix
        suffix = '_' if self.nested and space_name in family.consts else ''
        self.struct_name = 'struct ' + self.render_name + suffix
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        self.in_multi_val = False  # used by a MultiAttr or and legacy arrays

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in names]

        self.attrs = {}
        self.attr_max_val = None
        max_val = 0
        for name, attr in self.attr_list:
            # '>=' makes the last of equally valued attrs win
            if attr.value >= max_val:
                max_val = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        # Attribute names, in member order
        yield from self.attrs

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in member order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record the inherited members; they must match the ones given at construction."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = list(map(c_lower, sorted(self._inherited)))
811
812
class EnumEntry(SpecEnumEntry):
    """Enum entry with its rendered C name."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # True when the value is not simply previous + 1 (or 0 for the first
        # entry), and always for flags — presumably so the renderer emits an
        # explicit value for it
        expected = prev.value + 1 if prev else 0
        self.value_change = self.value != expected or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        # Entry name prefixed with the enum set's value prefix, upper-cased
        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
831
832
class EnumSet(SpecEnumSet):
    """Enum or flags definition together with the C names used to render it."""

    def __init__(self, family, yaml):
        # Everything below is set before the base init, which creates the
        # entries through new_entry()
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        # An explicitly empty 'enum-name' leaves the enum without a C type name
        if 'enum-name' in yaml:
            if yaml['enum-name']:
                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
            else:
                self.enum_name = None
        else:
            self.enum_name = 'enum ' + self.render_name

        # C type used for attrs referring to this enum; falls back to int
        # when there is no enum type name
        self.user_type = self.enum_name if self.enum_name else 'int'

        # Prefix of the entries' C names (see EnumEntry.resolve)
        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """
        Return (low, high), the smallest and largest entry values.

        Raises:
            ValueError: if the enum has no entries.
            Exception: if the entry values are not contiguous.
        """
        values = [x.value for x in self.entries.values()]
        if not values:
            raise ValueError("Can't get value range for an empty enum")
        low, high = min(values), max(values)

        if high - low + 1 != len(self.entries):
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
868
869
class AttrSet(SpecAttrSet):
    """An attribute set, extended with C naming and typed attribute objects."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # A subset shares the C enum (prefix, max and count names) of its parent.
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Compute the C identifier for the set.

        The name is lowered, suffixed with '_' if it is a C keyword, and
        left empty when it equals the family's own C name.
        """
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """Build the typed attribute object for spec element *elem*.

        Multi-attr elements are wrapped in TypeMultiAttr. Raises for
        unsupported types or indexed-array sub-types.
        """
        attr_type = elem['type']
        by_type = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'binary': TypeBinary,
            'bitfield32': TypeBitfield32,
            'nest': TypeNest,
            'nest-type-value': TypeNestTypeValue,
        }

        if attr_type in scalars:
            cls = TypeScalar
        elif attr_type in by_type:
            cls = by_type[attr_type]
        elif attr_type == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            cls = TypeArrayNest
        else:
            raise Exception(f"No typed class for type {elem['type']}")

        attr = cls(self.family, self, elem, value)
        if 'multi-attr' in elem and elem['multi-attr']:
            attr = TypeMultiAttr(self.family, self, elem, value, attr)
        return attr
930        return t
931
932
class Operation(SpecOperation):
    """A spec operation, extended with C naming and notification tracking."""

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # dual_policy: both the 'do' and the 'dump' form define a request.
        do_has_req = 'do' in yaml and 'request' in yaml['do']
        dump_has_req = 'dump' in yaml and 'request' in yaml['dump']
        self.dual_policy = do_has_req and dump_has_req

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Resolve the base op, then pick its C enum name (async ops use the async prefix)."""
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that some other operation names this one in its 'notify' key."""
        self.has_ntf = True
958
959
class Family(SpecFamily):
    """C code generator view of a genetlink family spec.

    Extends SpecFamily with C naming (c_name, op prefixes, uapi header).
    resolve() also derives the tables the renderers use:
      - root_sets: attribute sets used directly by messages, with the
        attributes they use in requests and in replies;
      - pure_nested_structs: Structs for sets reached only through nests;
      - request/reply usage of every attribute;
      - pre/post hooks;
      - optionally, a single global kernel policy.
    """

    def __init__(self, file_name, exclude_ops):
        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # "linux/<name>.h" -> "<name>"; any other header falls back to ident_name
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """Resolve the base spec, then compute C names and derived tables.

        Raises:
            Exception: if the protocol is not one of the genetlink variants.
        """
        self.resolve_up(super())

        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
            raise Exception("Codegen only supported for genetlink")

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] -> {'set': names seen, 'list': names in first-seen order}
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct, for sets only ever used as nests
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def _mark_notify(self):
        """Flag every op that another message lists in its 'notify' key."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Collect, per message attribute-set, the attrs used in requests and replies."""
        for op in self.msgs.values():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so nests come after the structs they embed.

        Recursive nests are skipped as dependencies, because those members
        are rendered through a pointer.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                    if self.pure_nested_structs[nested].recursive:
                        continue
                    if nested not in pns_key_seen:
                        # Dicts are sorted, this will make struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_sets(self):
        """Build a Struct for every set reachable through nests from a root set.

        Also propagates request/reply usage, multi-value usage and
        recursion from parents to children.

        Raises:
            Exception: if a set is used both as a root and as a nest, or if
                a nested set is inherited with differing members.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for _, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                # Keep this order: the Struct is created with 'inherit' before
                # the set is filled in below.
                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                # Root sets were rejected just above, so 'nested' is never a root here.
                if 'type-value' in spec:
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'indexed-array':
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' not in spec:
                    continue
                nested = spec['nested-attributes']
                if attr in rs_members['request']:
                    self.pure_nested_structs[nested].request = True
                if attr in rs_members['reply']:
                    self.pure_nested_structs[nested].reply = True

                # Must stay inside the nested-attributes case. For a multi-val
                # scalar, 'nested' would be unbound or left over from an
                # earlier attribute.
                if spec.is_multi_val():
                    child = self.pure_nested_structs.get(nested)
                    child.in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                    struct.child_nests.add(child_name)
                    child = self.pure_nested_structs.get(child_name)
                    if child:
                        if not child.recursive:
                            struct.child_nests.update(child.child_nests)
                        child.request |= struct.request
                        child.reply |= struct.reply
                        if spec.is_multi_val():
                            child.in_multi_val = True
                if attr_set in struct.child_nests:
                    struct.recursive = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Mark each attribute as used in requests and/or replies."""
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_global_policy(self):
        """Collect request attrs of all ops into one policy, kept in set order.

        Raises:
            Exception: if the ops do not all use the same attribute set.
        """
        global_set = set()
        attr_set_name = None
        for op in self.ops.values():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Record each distinct pre/post hook name per op mode, in first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1227
1228
class RenderInfo:
    """Rendering context for one operation/mode or one bare attribute set.

    Holds the code writer, family, op and mode, and precomputes:
      - fixed_hdr: the C struct type of the op's fixed header, if any;
      - type_consistent: False when the dump reply cannot reuse the do
        reply type;
      - type_name: C base name (from the op name, or from the attr set);
      - struct: Struct objects for the 'request'/'reply' directions.
    """
    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        """
        Args:
            cw: code writer used for output; its ``nlib`` is kept as ``self.nl``.
            family: the Family being rendered.
            ku_space: stored as-is; presumably selects kernel vs user-space
                output - TODO confirm against callers.
            op: the Operation being rendered, or a falsy value when rendering
                a bare attribute set (then *attr_set* must be provided).
            op_mode: e.g. 'do', 'dump', 'notify' or 'event'.
            attr_set: attribute-set name; defaults to ``op['attribute-set']``.
        """
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        self.fixed_hdr = None
        if op and op.fixed_header:
            self.fixed_hdr = 'struct ' + c_lower(op.fixed_header)

        # 'do' and 'dump' response parsing is identical
        # type_consistent stays True only when 'do' and 'dump' both exist and
        # either neither has a reply or both have an equal reply spec.
        # NOTE(review): `'dump' in op` needs op to support `in`; a None op
        # with op_mode != 'do' would raise TypeError here - confirm callers
        # pass a string/empty op instead.
        self.type_consistent = True
        if op_mode != 'do' and 'dump' in op:
            if 'do' in op:
                if ('reply' in op['do']) != ('reply' in op["dump"]):
                    self.type_consistent = False
                elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False
            else:
                self.type_consistent = False

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        # Without an op the type is named after the attr set; flag a
        # possible clash when the same name also exists in family.consts.
        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            if attr_set in family.consts:
                self.type_name_conflict = True

        self.cw = cw

        self.struct = dict()
        # Notifications reuse the 'do' message layout (only the local
        # op_mode is remapped; self.op_mode keeps 'notify').
        if op_mode == 'notify':
            op_mode = 'do'
        for op_dir in ['request', 'reply']:
            if op:
                type_list = []
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=type_list)
        # Events carry their attributes directly under op['event'].
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1277
1278
class CodeWriter:
    """Indentation-aware writer for generated C code.

    Output goes to stdout, or, when *out_file* is given, to a temporary
    file whose contents close_out_file() copies to *out_file*. With
    ``overwrite=False`` an existing *out_file* with identical contents is
    left untouched.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False
        self._block_end = False
        self._silent_block = False
        self._ind = 0
        self._ifdef_block = None
        # Set the stdout defaults first so __del__ -> close_out_file() is safe
        # even if creating the temporary file below fails.
        self._out = sys.stdout
        self._out_file = None
        if out_file is not None:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """Copy buffered output to the target file. No-op for stdout or when already closed."""
        if self._out_file is None:
            return
        # A bare block_end() holds back its '}' in case "else" follows;
        # don't lose it if nothing else was written afterwards.
        if self._block_end:
            self._block_end = False
            self._out.write('\t' * self._ind + '}\n')
        try:
            self._out.flush()
            # Avoid modifying the file if contents didn't change
            unchanged = (not self._overwrite and os.path.isfile(self._out_file) and
                         filecmp.cmp(self._out.name, self._out_file, shallow=False))
            if not unchanged:
                with open(self._out_file, 'w+') as out_file:
                    self._out.seek(0)
                    shutil.copyfileobj(self._out, out_file)
        finally:
            # Always release the temporary file, including when the target
            # is left unchanged.
            self._out.close()
            self._out = sys.stdout
            self._out_file = None

    @classmethod
    def _is_cond(cls, line):
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Write one line at the current indent.

        Special cases:
          - flushes a deferred '}' (merged into "} else ..." when *line*
            starts with "else");
          - dedents labels ending in ':';
          - indents the single statement after a brace-less if/while/for;
          - puts preprocessor lines ('#') in column 0.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        # startswith/endswith keep an empty line from raising IndexError
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next written line."""
        self._nl = True

    def block_start(self, line=''):
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype, wrapping arguments past 80 columns."""
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line."""
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        # Sort a copy so the caller's list is not reordered as a side effect.
        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write tab-aligned '#define NAME value' lines; str values are quoted."""
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write tab-aligned '.name = value,' designated initializers."""
        if not members:
            return
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch the active '#ifdef CONFIG_*' region (None/empty closes it)."""
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1473        self._ifdef_block = config_option
1474
1475
# Spec attribute types that are rendered as plain C scalar members.
scalars = set('u8 u16 u32 u64 s8 s16 s32 s64 uint sint'.split())

# Type-name suffix for each message direction ('' means no direction).
direction_to_suffix = dict([
    ('reply', '_rsp'),
    ('request', '_req'),
    ('', ''),
])

# Type-name suffix op_prefix() uses for the reply wrapper of each op mode.
op_mode_to_wrapper = dict(do='', dump='_list', notify='_ntf', event='')

# C keywords; AttrSet.resolve() appends '_' to set names that collide.
_C_KW = set("""
    auto bool break case char const continue default do double else enum
    extern float for goto if inline int long register return short signed
    sizeof static struct switch typedef union unsigned void volatile while
""".split())
1526}
1527
1528
def rdir(direction):
    """Return the opposite direction: 'request' <-> 'reply'; anything else is returned unchanged."""
    if direction not in ('reply', 'request'):
        return direction
    return 'request' if direction == 'reply' else 'reply'
1535
1536
def op_prefix(ri, direction, deref=False):
    """Build the C struct base name for *ri* in *direction*.

    For 'do' (or no mode) the name is the direction suffix. For other
    modes, requests get '_req_dump'. Replies use the per-mode wrapper
    suffix (or the plain direction suffix when *deref*) if the do/dump
    types match, and '_rsp_dump' / '_rsp_list' otherwise.
    """
    suffix = f"_{ri.type_name}"

    if not ri.op_mode or ri.op_mode == 'do':
        suffix += f"{direction_to_suffix[direction]}"
    elif direction == 'request':
        suffix += '_req_dump'
    elif not ri.type_consistent:
        suffix += '_rsp' + ('_dump' if deref else '_list')
    elif deref:
        suffix += f"{direction_to_suffix[direction]}"
    else:
        suffix += op_mode_to_wrapper[ri.op_mode]

    return f"{ri.family.c_name}{suffix}"
1556
1557
def type_name(ri, direction, deref=False):
    """Return the full C struct type ("struct <prefix>") for *ri* in *direction*."""
    prefix = op_prefix(ri, direction, deref=deref)
    return f"struct {prefix}"
1560
1561
def print_prototype(ri, direction, terminate=True, doc=None):
    """Write the user-space request function prototype for ri.op.

    Signature: (ys[, request]). The return type is a pointer to the
    reply struct when the mode has a reply, else int. With *terminate*
    the prototype ends in ';'.
    """
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')
    mode_spec = ri.op[ri.op_mode]

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        req_name = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{req_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=';' if terminate else '')
1578
1579
def print_req_prototype(ri):
    """Write the 'do' request prototype, preceded by the op's doc comment."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1582
1583
def print_dump_prototype(ri):
    """Write the dump request prototype (no doc comment)."""
    print_prototype(ri, direction="request")
1586
1587
def put_typol_fwd(cw, struct):
    """Emit the extern declaration of the struct's ynl_policy_nest."""
    nest_name = f'{struct.render_name}_nest'
    cw.p(f'extern const struct ynl_policy_nest {nest_name};')
1590
1591
def put_typol(cw, struct):
    """Emit the user-space policy table and the policy-nest descriptor for *struct*."""
    max_attr = struct.attr_set.max_name
    name = struct.render_name

    # Per-attribute policy array, indexed by attribute type
    cw.block_start(line=f'const struct ynl_policy_attr {name}_policy[{max_attr} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    # Nest descriptor pointing at the table above
    cw.block_start(line=f'const struct ynl_policy_nest {name}_nest =')
    for line in (f'.max_attr = {max_attr},', f'.table = {name}_policy,'):
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()
1607
1608
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit '<render_name>_str()', which maps a value to its string via *map_name*.

    The argument type is ``enum.user_type`` when *enum* is given, else int.
    For flags enums the single-bit value is first turned into a bit index
    with ffs(). Out-of-range values return NULL.
    """
    arg_type = enum.user_type if enum else 'int'
    cw.write_func_prot('const char *', f'{render_name}_str', [f'{arg_type} {arg_name}'])
    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    body = (
        f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))',
        'return NULL;',
        f'return {map_name}[{arg_name}];',
    )
    for line in body:
        cw.p(line)
    cw.block_end()
    cw.nl()
1622
1623
def put_op_name_fwd(family, cw):
    """Declare '<family>_op_str(int op)'."""
    func_name = f'{family.c_name}_op_str'
    cw.write_func_prot('const char *', func_name, ['int op'], suffix=';')
1626
1627
def put_op_name(family, cw):
    """Emit the op-value -> op-name string table and its '<family>_op_str()' lookup."""
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue

        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
1647
1648
def put_enum_to_str_fwd(family, cw, enum):
    """Declare '<enum>_str(<user_type> value)'."""
    param = f'{enum.user_type} value'
    cw.write_func_prot('const char *', f'{enum.render_name}_str', [param], suffix=';')
1652
1653
def put_enum_to_str(family, cw, enum):
    """Emit the value -> name string table for *enum* and its '_str()' lookup."""
    map_name = f'{enum.render_name}_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for line in (f'[{e.value}] = "{e.name}",' for e in enum.entries.values()):
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1663
1664
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Write the prototype of '<struct>_put()', which serializes a nest into nlh."""
    func_name = f'{struct.render_name}_put'
    func_args = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]
    ri.cw.write_func_prot('int', func_name, func_args, suffix=suffix)
1672
1673
def put_req_nested(ri, struct):
    """Emit the body of '<struct>_put()': open a nest, put each member, close it."""
    local_vars = ['struct nlattr *nest;']
    # Count-presence members are emitted with a loop, which needs an index
    if any(arg.presence_type() == 'count' for _, arg in struct.member_list()):
        local_vars.append('unsigned int i;')

    put_req_nested_prototype(ri, struct, suffix='')
    cw = ri.cw
    cw.block_start()
    cw.write_func_lvar(local_vars)

    cw.p("nest = ynl_attr_nest_start(nlh, attr_type);")

    for _, arg in struct.member_list():
        arg.attr_put(ri, "obj")

    cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1702
1703
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a C parse function for ``struct``.

    Shared by nested-struct parsers and top-level reply parsers; the caller
    has already emitted the prototype.  The generated code walks the
    attributes once, dispatching to each member's ``attr_get()``, then, for
    indexed-array and multi-attr members, allocates ``n_<name>`` elements
    and walks the attributes a second time to fill them.  (Presumably the
    first pass only counts those members' occurrences in ``n_<name>`` —
    confirm in their ``attr_get()`` implementations.)

    Args:
        ri: render info; ``ri.cw`` receives the generated C and
            ``ri.fixed_hdr`` selects whether a fixed header is copied out.
        struct: the struct to parse; ``struct.nested`` selects nested-attr
            iteration versus top-level message iteration.
        init_lines: C statements emitted right after the local variables.
            This function appends to it.
        local_vars: C local variable declarations.  This function appends
            to it.

    Raises:
        Exception: for an indexed-array sub-type, or an array/multi-attr
            element type, that has no parsing support.
    """
    # Pick the C iterator: nested structs walk the attributes inside
    # "nested", messages walk the top-level attributes after the family
    # header.
    if struct.nested:
        iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        if ri.fixed_hdr:
            local_vars += ['void *hdr;']
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"

    # Classify members producing arrays.  Indexed arrays (nest or scalar
    # sub-type) get an attr_<name> pointer, which the second pass iterates
    # as the container; multi-attr members are collected from repeated
    # attributes at this level.
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] == 'nest':
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        # Nested members are parsed by calling their own _parse() helpers,
        # which take a ynl_parse_arg.
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # One element counter per array-producing member.
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    # Inherited members are copied from the same-named function parameters.
    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    # NOTE(review): 'void *hdr;' is declared above only for non-nested
    # structs, but this copy is gated on ri.fixed_hdr alone; presumably
    # ri.fixed_hdr is never set when parsing nests -- confirm.
    if ri.fixed_hdr:
        ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
    # Generated code refuses to parse into a dst whose array member is
    # already set.
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    ri.cw.nl()
    # First pass: dispatch on the attribute type to each member's attr_get().
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        # (presumably attr_get() returns truthy once it has emitted a
        # branch, so 'first' stays set until the first real branch).
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for indexed arrays: allocate n_<name> elements and fill
    # them by walking the container attribute saved in attr_<name>.
    # NOTE(review): the generated calloc() result is not NULL-checked.
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attr members: re-walk the same attribute level
    # and pick out every occurrence of the member's attribute type.
    # NOTE(review): the generated calloc() result is not NULL-checked.
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return 0; message parsers return the YNL callback code.
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1824
1825
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Emit the C prototype of ``<struct>_parse()`` for a nested struct.

    Every name in ``struct.inherited`` becomes an extra ``__u32`` parameter
    after the parse argument and the nest attribute.
    """
    params = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    params += ['__u32 ' + name for name in struct.inherited]
    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', params,
                          suffix=suffix)
1834
1835
def parse_rsp_nested(ri, struct):
    """Emit the C definition of ``<struct>_parse()`` for a nested struct.

    Structs with members are handled by _multi_parse(); an empty nest gets
    a trivial body that returns 0.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    c_vars = ['const struct nlattr *attr;',
              f'{struct.ptr_name}dst = yarg->data;']

    if not struct.member_list():
        # Empty nest: nothing to parse.
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, struct, [], c_vars)
1851
1852
def parse_rsp_msg(ri, deref=False):
    """Emit the C parse callback for the current op's reply message.

    Nothing is emitted unless the op mode has a reply or is an event.  A
    reply with members is parsed by _multi_parse(); an empty reply gets a
    body that simply returns YNL_PARSE_CB_OK.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    reply_type = type_name(ri, "reply", deref=deref)
    c_vars = [f'{reply_type} *dst;', 'const struct nlattr *attr;']
    fn_name = f'{op_prefix(ri, "reply", deref=deref)}_parse'

    ri.cw.write_func_prot('int', fn_name,
                          ['const struct nlmsghdr *nlh',
                           'struct ynl_parse_arg *yarg'])

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply: accept the message without parsing anything.
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply, ['dst = yarg->data;'], c_vars)
1874
1875
def print_req(ri):
    """Emit the C definition of a "do" request function.

    The generated function builds the request: an optional fixed header,
    then every request attribute.  It runs the request with ynl_exec().
    If the op has a reply, it returns a newly allocated, parsed reply, or
    NULL on error after freeing it.  Otherwise it returns 0 on success and
    -1 on error.
    """
    direction = "request"
    cw = ri.cw
    has_reply = 'reply' in ri.op[ri.op_mode]

    c_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
              'struct nlmsghdr *nlh;',
              'int err;']
    if has_reply:
        c_vars.append(f'{type_name(ri, rdir(direction))} *rsp;')
    if ri.fixed_hdr:
        c_vars.extend(['size_t hdr_len;', 'void *hdr;'])
    if any(attr.presence_type() == 'count'
           for _, attr in ri.struct["request"].member_list()):
        c_vars.append('unsigned int i;')

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(c_vars)

    cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()

    if ri.fixed_hdr:
        for stmt in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            cw.p(stmt)
        cw.nl()

    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    cw.nl()

    if has_reply:
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yrs.yarg.data = rsp;')
        cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        # Ops without their own value answer with the reply value.
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        cw.nl()

    cw.p("err = ynl_exec(ys, nlh, &yrs);")
    cw.p('if (err < 0)')
    cw.p('goto err_free;' if has_reply else 'return -1;')
    cw.nl()

    cw.p(f"return {'rsp' if has_reply else '0'};")
    cw.nl()

    if has_reply:
        cw.p('err_free:')
        cw.p(call_free(ri, rdir(direction), 'rsp'))
        cw.p("return NULL;")

    cw.block_end()
1945
1946
def print_dump(ri):
    """Emit the C definition of a dump request function.

    The generated function configures a ynl_dump_state.  Its element size
    is the reply type and its callback is the reply parser.  The function
    builds the dump request (optional fixed header and any request
    attributes) and runs ynl_exec_dump().  It returns ``yds.first``.  On
    error it frees the list collected so far and returns NULL.
    """
    direction = "request"
    cw = ri.cw

    print_prototype(ri, direction, terminate=False)
    cw.block_start()

    c_vars = ['struct ynl_dump_state yds = {};',
              'struct nlmsghdr *nlh;',
              'int err;']
    if ri.fixed_hdr:
        c_vars.extend(['size_t hdr_len;', 'void *hdr;'])
    cw.write_func_lvar(c_vars)

    cw.p('yds.yarg.ys = ys;')
    cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.p("yds.yarg.data = NULL;")
    cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    # Ops without their own value answer with the reply value.
    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    cw.p(f'yds.rsp_cmd = {rsp_cmd};')
    cw.nl()
    cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if ri.fixed_hdr:
        for stmt in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            cw.p(stmt)
        cw.nl()

    if "request" in ri.op[ri.op_mode]:
        cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    cw.nl()

    cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    cw.p('if (err < 0)')
    cw.p('goto free_list;')
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
1997
1998
def call_free(ri, direction, var):
    """Return the C statement that releases ``var`` with the op's free helper."""
    helper = op_prefix(ri, direction)
    return f"{helper}_free({var});"
2001
2002
def free_arg_name(direction):
    """Return the parameter name used by a generated ``*_free()`` helper.

    Direction-less structs use ``obj``.  Otherwise the name is the
    direction's entry in ``direction_to_suffix`` without its first
    character (presumably the leading '_').
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
2007
2008
def print_alloc_wrapper(ri, direction):
    """Emit a ``static inline`` ``<op>_alloc()`` returning one calloc()ed struct."""
    name = op_prefix(ri, direction)
    cw = ri.cw
    cw.write_func_prot(f'static inline struct {name} *', f"{name}_alloc", ["void"])
    cw.block_start()
    cw.p(f'return calloc(1, sizeof(struct {name}));')
    cw.block_end()
2015
2016
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of ``<op>_free()``.

    The argument's struct tag gets a trailing underscore when
    ``ri.type_name_conflict`` is set.
    """
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{name}_free", [param], suffix=suffix)
2024
2025
def _print_type(ri, direction, struct):
    """Emit the C struct definition for ``struct``.

    The struct is named ``<family>_<type_name><direction suffix>``, with a
    trailing '_' for a direction-less name conflict and ``_dump`` in dump
    mode.  Its body holds:
      * the fixed header (``_hdr``), if any;
      * a nested ``_present`` struct with the members' presence fields;
      * one ``__u32`` per inherited value;
      * each member's own field(s).
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        suffix += '_'
    if ri.op_mode == 'dump':
        suffix += '_dump'

    cw = ri.cw
    cw.block_start(line=f"struct {ri.family.c_name}{suffix}")

    if ri.fixed_hdr:
        cw.p(ri.fixed_hdr + ' _hdr;')
        cw.nl()

    presence_lines = []
    for _, attr in struct.member_list():
        for kind in ('len', 'bit'):
            line = attr.presence_member(ri.ku_space, kind)
            if line:
                presence_lines.append(line)

    if presence_lines:
        cw.block_start(line="struct")
        for line in presence_lines:
            cw.p(line)
        cw.block_end(line='_present;')
        cw.nl()

    for name in struct.inherited:
        cw.p(f"__u32 {name};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
2061
2062
def print_type(ri, direction):
    """Emit the C struct for one direction (e.g. request/reply) of the op."""
    _print_type(ri, direction, struct=ri.struct[direction])
2065
2066
def print_type_full(ri, struct):
    """Emit ``struct`` as a standalone C struct with no direction suffix."""
    _print_type(ri, direction="", struct=struct)
2069
2070
def print_type_helpers(ri, direction, deref=False):
    """Emit the free() prototype and, for user-space requests, member setters."""
    print_free_prototype(ri, direction)
    ri.cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
2079
2080
def print_req_type_helpers(ri):
    """Emit alloc/free/setter helpers for the request struct, if it has attributes."""
    if not ri.struct["request"].attr_list:
        return
    print_alloc_wrapper(ri, "request")
    print_type_helpers(ri, "request")
2086
2087
def print_rsp_type_helpers(ri):
    """Emit the reply struct's helpers when the current op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2092
2093
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of ``<op>_rsp_parse()`` / ``<op>_req_parse()``.

    The function takes the attribute table and a pointer to the matching
    struct.  That parameter is named ``req`` in both directions.
    """
    tag = "_rsp" if direction == "reply" else "_req"
    base = f"{ri.op.render_name}{tag}"
    params = ['const struct nlattr **tb', f"struct {base} *req"]
    ri.cw.write_func_prot('void', f"{base}_parse", params,
                          suffix=';' if terminate else '')
2102
2103
def print_req_type(ri):
    """Emit the request struct, unless the request has no attributes."""
    if ri.struct["request"].attr_list:
        print_type(ri, "request")
2108
2109
def print_req_free(ri):
    """Emit the request ``*_free()`` helper when the op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2114
2115
def print_rsp_type(ri):
    """Emit the reply struct for events, and for do/dump modes that have a reply."""
    mode = ri.op_mode
    if mode == 'event' or (mode in ('do', 'dump') and 'reply' in ri.op[mode]):
        print_type(ri, 'reply')
2124
2125
def print_wrapped_type(ri):
    """Emit the wrapper struct that holds a reply object in ``obj``.

    Dump wrappers get a ``next`` pointer.  Notify/event wrappers get
    ``family``, ``cmd``, a ``next`` pointer and a ``free`` callback.  The
    free() prototype for the wrapper follows the struct.
    """
    cw = ri.cw
    wrapper = type_name(ri, 'reply')

    cw.block_start(line=f"{wrapper}")
    if ri.op_mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        for field in ('__u16 family;',
                      '__u8 cmd;',
                      'struct ynl_ntf_base_type *next;',
                      f"void (*free)({wrapper} *ntf);"):
            cw.p(field)
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
2140
2141
def _free_type_members_iter(ri, struct):
    """Declare the C loop counter ``i`` if any member's free code iterates."""
    if any(member.free_needs_iter() for _, member in struct.member_list()):
        ri.cw.p('unsigned int i;')
        ri.cw.nl()
2148
2149
def _free_type_members(ri, var, struct, ref=''):
    """Emit each member's free code for the C object ``var``.

    ``ref`` is the access path prefix inside ``var`` (e.g. ``'obj.'``).
    """
    members = (member for _, member in struct.member_list())
    for member in members:
        member.free(ri, var, ref)
2153
2154
def _free_type(ri, direction, struct):
    """Emit the ``*_free()`` definition for ``struct``.

    Members are released first.  The object itself is free()d only when a
    direction is given; direction-less structs (used for nests) only have
    their members released.
    """
    arg_name = free_arg_name(direction)
    cw = ri.cw

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg_name, struct)
    if direction:
        cw.p(f'free({arg_name});')
    cw.block_end()
    cw.nl()
2166
2167
def free_rsp_nested_prototype(ri):
    """Emit the ``*_free()`` prototype for a nested (direction-less) struct."""
    print_free_prototype(ri, direction="")
2170
2171
def free_rsp_nested(ri, struct):
    """Emit the ``*_free()`` definition for a nested (direction-less) struct."""
    _free_type(ri, direction="", struct=struct)
2174
2175
def print_rsp_free(ri):
    """Emit the reply ``*_free()`` helper when the current op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2180
2181
def print_dump_type_free(ri):
    """Emit ``<op>_free()`` that walks and releases a dump result list.

    The generated loop follows ``next`` until the YNL_LIST_END sentinel.
    It frees each node's members (reached through ``obj.``) and then the
    node.
    """
    cw = ri.cw
    node_type = type_name(ri, 'reply')
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{node_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply)
    cw.p('rsp = next;')
    cw.p('next = rsp->next;')
    cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
2200
2201
def print_ntf_type_free(ri):
    """Emit ``<op>_free()`` for a wrapped reply object held in ``rsp``.

    Members are reached through ``obj.``; the wrapper itself is freed last.
    """
    reply = ri.struct['reply']
    cw = ri.cw

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
2210
2211
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit an ``nla_policy`` array declaration, or the opening of its definition.

    With ``terminate=True`` an ``extern`` declaration is printed.  It is
    skipped entirely for per-op policies that policy_should_be_static()
    keeps private.  With ``terminate=False`` the definition opener
    (``... = {``) is printed, ``static`` when applicable.  Per-op arrays
    are named after the op (plus the op mode for dual policies); others
    after the struct.
    """
    is_static = bool(ri) and policy_should_be_static(struct.family)

    if terminate:
        if is_static:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if is_static else ''
        suffix = ' = {'

    if ri:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    else:
        name = struct.render_name

    max_enum = struct.attr_max_val.enum_name
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_enum} + 1]{suffix}")
2234
2235
def print_req_policy(cw, struct, ri=None):
    """Emit a complete ``nla_policy`` array for ``struct``.

    When rendering for an op, the array goes through cw.ifdef_block() with
    the op's ``config-cond`` (if any).  cw.ifdef_block(None) follows the
    closing brace.
    """
    for_op = ri and ri.op
    if for_op:
        cw.ifdef_block(ri.op.get('config-cond', None))
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
    cw.ifdef_block(None)
    cw.nl()
2245
2246
def kernel_can_gen_family_struct(family):
    """Return True if the genl_family struct is generated (plain 'genetlink' only)."""
    protocol = family.proto
    return protocol == 'genetlink'
2249
2250
def policy_should_be_static(family):
    """Return True if per-op policies stay file-local (``static``).

    This holds for split kernel policies, and for families whose
    genl_family struct is generated.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2253
2254
def print_kernel_policy_ranges(family, cw):
    """Emit netlink_range_validation structs for 'full-range' request attributes.

    Only request attributes of non-subset attribute sets are considered.
    Unsigned types use ``netlink_range_validation`` with ``ULL`` limits;
    others use the ``_signed`` variant with ``LL`` limits.  A
    ``/* Integer value ranges */`` comment precedes the first struct.
    """
    header_emitted = False
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if not header_emitted:
                cw.p('/* Integer value ranges */')
                header_emitted = True

            is_unsigned = attr.type[0] == 'u'
            sign = '' if is_unsigned else '_signed'
            suffix = 'ULL' if is_unsigned else 'LL'
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            limits = [(bound, attr.get_limit_str(bound, suffix=suffix))
                      for bound in ('min', 'max') if bound in attr.checks]
            cw.write_struct_init(limits)
            cw.block_end(line=';')
            cw.nl()
2282
2283
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the ops table declaration/opener and, for headers, handler prototypes.

    With ``terminate=False`` (source file) it only opens the
    ``<family>_nl_ops`` array definition and returns.  With
    ``terminate=True`` (header) it first emits an ``extern`` declaration,
    but only when the table is exported, i.e. when the family struct is
    not generated.  It then emits prototypes for the pre/post hooks and
    for every do/dump handler of synchronous ops.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.ident_name} */")

        struct_type = {'global': 'genl_small_ops',
                       'per-op': 'genl_ops',
                       'split': 'genl_split_ops'}[family.kernel_policy]

        # Only exported tables carry an explicit element count.
        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            # Split tables have one entry per do and one per dump handler.
            cnt = sum(int('do' in op) + int('dump' in op)
                      for op in family.ops.values())
        else:
            cnt = len(family.ops)

        qual = 'const' if exported else 'static const'
        decl = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {decl};")
        else:
            cw.block_start(line=decl + ' =')

    if not terminate:
        return

    cw.nl()
    do_hook_args = ['const struct genl_split_ops *ops',
                    'struct sk_buff *skb', 'struct genl_info *info']
    dump_hook_args = ['struct netlink_callback *cb']
    hook_kinds = (('pre', 'do', 'int'),
                  ('post', 'do', 'void'),
                  ('pre', 'dump', 'int'),
                  ('post', 'dump', 'int'))
    for when, mode, ret_type in hook_kinds:
        args = do_hook_args if mode == 'do' else dump_hook_args
        for hook in family.hooks[when][mode]['list']:
            cw.write_func_prot(ret_type, c_lower(hook), list(args), suffix=';')

    cw.nl()

    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        for mode, last_arg in (('do', 'struct genl_info *info'),
                               ('dump', 'struct netlink_callback *cb')):
            if mode in op:
                handler = c_lower(f"{family.ident_name}-nl-{op_name}-{mode}it")
                cw.write_func_prot('int', handler,
                                   ['struct sk_buff *skb', last_arg], suffix=';')
    cw.nl()
2349
2350
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side ops table declaration and handler prototypes."""
    print_kernel_op_table_fwd(family, cw, True)
2353
2354
def print_kernel_op_table(family, cw):
    """Emit the definition of the kernel ops table (``<family>_nl_ops``).

    For 'global' and 'per-op' kernel policies one entry is emitted per
    synchronous op, listing both its doit and dumpit handlers.  For 'split'
    policies one entry is emitted per (op, do/dump) pair.  Each entry
    passes the op's ``config-cond`` (if any) to cw.ifdef_block().
    """
    # Opens the "<qual> struct <ops type> <family>_nl_ops[...] =" block.
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): assumes every op has do.request here; an op
                # without it would raise KeyError -- confirm specs guarantee it.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Split ops name their pre/post callbacks differently per mode.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the dont-validate flags relevant to this mode:
                    # dump-specific flags are dropped for do entries and
                    # 'strict' is dropped for dump entries.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have separate do/dump policies.
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split entry carries GENL_CMD_CAP_DO or GENL_CMD_CAP_DUMP.
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    # Presumably closes any conditional left open by the last entry.
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2433
2434
def print_kernel_mcgrp_hdr(family, cw):
    """Emit an anonymous enum with one ``<FAMILY>_NLGRP_<NAME>`` per multicast group."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for group in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{group['name']},"))
    cw.block_end(';')
    cw.nl()
2445
2446
def print_kernel_mcgrp_src(family, cw):
    """Emit the ``genl_multicast_group`` array, indexed by the group enum."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for group in groups:
        grp_name = group['name']
        grp_id = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{grp_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2458
2459
def print_kernel_family_struct_hdr(family, cw):
    """Emit header declarations for the generated genl_family.

    This is the ``extern`` family declaration, plus the sock-priv
    init/destroy prototypes when the spec sets ``sock-priv``.  Nothing is
    emitted when the family struct is not generated.
    """
    if not kernel_can_gen_family_struct(family):
        return

    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
    cw.nl()
    if 'sock-priv' in family.kernel_family:
        priv_type = family.kernel_family["sock-priv"]
        for action in ('init', 'destroy'):
            cw.p(f'void {family.c_name}_nl_sock_priv_{action}({priv_type} *priv);')
        cw.nl()
2470
2471
def print_kernel_family_struct_src(family, cw):
    """Emit the ``struct genl_family`` definition for the family.

    Nothing is emitted unless kernel_can_gen_family_struct() allows it.
    The ops table is referenced as ``.ops`` ('per-op') or ``.split_ops``
    ('split').  The multicast groups are referenced when present.  When the
    spec sets ``sock-priv``, the per-socket private area is wired up
    through static trampolines.
    """
    if not kernel_can_gen_family_struct(family):
        return

    c_name = family.c_name

    if 'sock-priv' in family.kernel_family:
        # Generate "trampolines" to make CFI happy: the family hooks take a
        # void *, while the user callbacks take the typed sock-priv struct.
        cw.write_func("static void", f"__{c_name}_nl_sock_priv_init",
                      [f"{c_name}_nl_sock_priv_init(priv);"],
                      ["void *priv"])
        cw.nl()
        cw.write_func("static void", f"__{c_name}_nl_sock_priv_destroy",
                      [f"{c_name}_nl_sock_priv_destroy(priv);"],
                      ["void *priv"])
        cw.nl()

    # Name the struct with c_name (the lower-cased identifier) so the
    # definition matches the "extern struct genl_family <c_name>_nl_family"
    # header declaration and every other <c_name>_* symbol referenced here;
    # ident_name is not lower-cased.
    cw.block_start(f"struct genl_family {c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({c_name}_nl_mcgrps),')
    if 'sock-priv' in family.kernel_family:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        cw.p(f'.sock_priv_init\t= __{c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = __{c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
2507
2508
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a uapi ``enum`` block, naming it from the spec when possible.

    If ``obj`` has the ``enum_name`` key, its value names the enum; an
    empty value gives an anonymous enum.  Otherwise a ``ckey`` entry is
    used, prefixed with the family's C name.  Failing both, the enum is
    anonymous.
    """
    if enum_name in obj:
        explicit = obj[enum_name]
        header = f'enum {c_lower(explicit)}' if explicit else 'enum'
    elif ckey and ckey in obj:
        header = f'enum {family.c_name}_{c_lower(obj[ckey])}'
    else:
        header = 'enum'
    cw.block_start(line=header)
2517
2518
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """Emit the single uapi command enum used by unified-messaging families.

    Values are written explicitly only where they break the running
    sequence.  Notification/event messages are left out when
    ``separate_ntf`` is set.  The enum ends with a count entry and a max
    entry.  With ``max_by_define`` the max is a ``#define`` after the enum
    instead.
    """
    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    expected = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        if op.value == expected:
            cw.p(f"{op.enum_name},")
        else:
            cw.p(f"{op.enum_name} = {op.value},")
            expected = op.value
        expected += 1
    cw.nl()

    if max_by_define:
        cw.p(cnt_name)
        cw.block_end(line=';')
        cw.p(f"#define {max_name} {max_value}")
    else:
        cw.p(f"{cnt_name},")
        cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
    cw.nl()
2544
2545
def render_uapi_directional(family, cw, max_by_define):
    """Emit the USER and KERNEL uapi message enums for directional messaging.

    The USER enum lists ops with a ``do`` that are not events.  The KERNEL
    enum lists do-replies (suffixed ``_REPLY``), notifications and events.
    Each enum starts with ``<FAMILY>_MSG_<DIR>_NONE = 0`` and ends with
    count/max markers, the max optionally as a ``#define``.
    """
    def emit_enum(direction, entries):
        # entries yields (enumerator name, op value) pairs in spec order.
        max_name = f"{family.op_prefix}{direction}_MAX"
        cnt_name = f"__{family.op_prefix}{direction}_CNT"
        max_value = f"({cnt_name} - 1)"

        cw.block_start(line='enum')
        cw.p(c_upper(f'{family.name}_MSG_{direction}_NONE = 0,'))
        val = 0
        for name, value in entries:
            suffix = ','
            if value and value != val:
                suffix = f" = {value},"
                val = value
            cw.p(name + suffix)
            val += 1
        cw.nl()
        cw.p(cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {max_name} {max_value}")
        cw.nl()

    def kernel_name(op):
        # Plain do-replies get a _REPLY suffix; ntf/event keep their name.
        if 'event' in op or 'notify' in op:
            return op.enum_name
        return f'{op.enum_name}_REPLY'

    emit_enum('USER',
              ((op.enum_name, op.value) for op in family.msgs.values()
               if 'do' in op and 'event' not in op))
    emit_enum('KERNEL',
              ((kernel_name(op), op.value) for op in family.msgs.values()
               if ('do' in op and 'reply' in op['do'])
               or 'notify' in op or 'event' in op))
2598
2599
def render_uapi(family, cw):
    """Render the UAPI header for *family* through the code writer *cw*.

    Output order: include guard, family name/version defines, the spec's
    'definitions' (consts as #defines, enums/flags as C enums with optional
    kdoc), one enum per attribute set that is not a subset, the command
    enum(s), the optional separate notification enum, the multicast group
    name defines and finally the closing #endif.

    Raises:
        Exception: if the family uses an unsupported message enum model.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consts are batched into 'defines' and flushed whenever a non-const
    # definition is reached, so the output keeps the spec's ordering.
    defines = []
    for const in family['definitions']:
        # Definitions that name an existing header are skipped here.
        if const.get('header'):
            continue

        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] in ('enum', 'flags'):
            enum = family.consts[const['name']]

            if enum.header:
                continue

            if enum.has_doc():
                if enum.has_entry_doc():
                    # Entries are documented: use a '/**' comment with
                    # '@entry:' lines.
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    cnt_name = enum.enum_cnt_name
                    max_name = c_upper(name_pfx + 'max')
                    if not cnt_name:
                        cnt_name = '__' + name_pfx + 'max'
                    cw.p(c_upper(cnt_name) + ',')
                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # The macro-name override is looked up on the constant itself,
            # like 'c-define-name' on multicast groups below; a family-wide
            # key would give every constant the same macro name.
            defines.append([c_upper(const.get('c-define-name',
                                              f"{family.ident_name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for attr_set in family.attr_sets.values():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            # Spell out only values that break the implicit sequence.
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        if attr_set.items():
            cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    if family.msg_id_model == 'unified':
        render_uapi_unified(family, cw, max_by_define, separate_ntf)
    elif family.msg_id_model == 'directional':
        render_uapi_directional(family, cw, max_by_define)
    else:
        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')

    # With an async prefix, notifications and events get their own enum
    # (render_uapi_unified leaves them out of the command enum).
    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if 'notify' not in op and 'event' not in op:
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2740
2741
def _render_user_ntf_entry(ri, op):
    """Write one ynl_ntf_info designated initializer for *op*.

    The entry is indexed by the op's enum name and fills in the allocation
    size of the event type, the reply parser, the reply policy nest and the
    notify free helper, all derived from the render info *ri*.
    """
    cw = ri.cw
    cw.block_start(line=f"[{op.enum_name}] = ")
    event_type = type_name(ri, 'event')
    cw.p(f".alloc_sz\t= sizeof({event_type}),")
    parse_fn = op_prefix(ri, 'reply', deref=True)
    cw.p(f".cb\t\t= {parse_fn}_parse,")
    policy_nest = ri.struct['reply'].render_name
    cw.p(f".policy\t\t= &{policy_nest}_nest,")
    free_fn = op_prefix(ri, 'notify')
    cw.p(f".free\t\t= (void *){free_fn}_free,")
    cw.block_end(line=',')
2749
2750
def render_user_family(family, cw, prototype):
    """Emit the user-mode ynl_family descriptor for *family*.

    With *prototype* set only the extern declaration of the descriptor is
    written.  Otherwise the static notification info table is written first
    (only when the family has notifications), followed by the definition of
    the ynl_family structure that references it.

    Raises:
        Exception: if an entry of family.ntfs has neither 'notify' nor 'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    # The table name is a C identifier: build it from c_name, as the
    # ynl_family symbol above is, rather than from the raw spec name, which
    # may not be a valid C identifier (e.g. if it contains '-').
    ntf_table = f"{family.c_name}_ntf_info"

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {ntf_table}[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                # Rendered using the op named by the 'notify' key.
                op = family.ops[ntf_op['notify']]
                ri = RenderInfo(cw, family, "user", op, "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        for op_name, op in family.ops.items():
            # Events already listed in family.ntfs were emitted above; emitting
            # them again would duplicate the designated initializer.
            if 'event' not in op or op_name in family.ntfs:
                continue
            ri = RenderInfo(cw, family, "user", op, "event")
            _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {ntf_table},")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({ntf_table}),")
    cw.block_end(line=';')
2786
2787
def family_contains_bitfield32(family):
    """Return True if any attribute of a non-subset attribute set is bitfield32.

    Attribute sets that are a subset of another set are skipped; only the
    attributes of full sets are inspected.
    """
    return any(attr.type == "bitfield32"
               for _, attr_set in family.attr_sets.items()
               if not attr_set.subset_of
               for _, attr in attr_set.items())
2796
2797
def find_kernel_root(full_path):
    """Locate the kernel source tree that contains *full_path*.

    Walk up the parent directories of *full_path* (lexically, via
    os.path.dirname) until one containing a MAINTAINERS file is found.

    Returns:
        (root, sub_path): root is the directory holding MAINTAINERS, in the
        same relative/absolute form as *full_path*; sub_path is *full_path*
        relative to root.

    Raises:
        FileNotFoundError: if no parent directory contains MAINTAINERS.
    """
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            return parent, sub_path[:-1]
        # dirname() stops changing at '/' (absolute) or '' (relative); past
        # that point the search can never succeed, so bail out instead of
        # looping forever.
        if parent == full_path:
            raise FileNotFoundError(
                f"no MAINTAINERS file found above {sub_path[:-1]!r}; "
                "is the spec inside a kernel source tree?")
        full_path = parent
2806
2807
def _render_kernel_header(parsed, cw, mode):
    """Emit the kernel-mode header body: policy forward declarations and the
    op table, multicast group and family struct declarations."""
    if any(struct.request for struct in parsed.pure_nested_structs.values()):
        cw.p('/* Common nested types */')
    for _, struct in sorted(parsed.pure_nested_structs.items()):
        if struct.request:
            print_req_policy_fwd(cw, struct)
    cw.nl()

    if parsed.kernel_policy == 'global':
        cw.p(f"/* Global operation policy for {parsed.name} */")

        struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
        print_req_policy_fwd(cw, struct)
        cw.nl()

    if parsed.kernel_policy in {'per-op', 'split'}:
        for op in parsed.ops.values():
            if 'do' in op and 'event' not in op:
                ri = RenderInfo(cw, parsed, mode, op, "do")
                print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                cw.nl()

    print_kernel_op_table_hdr(parsed, cw)
    print_kernel_mcgrp_hdr(parsed, cw)
    print_kernel_family_struct_hdr(parsed, cw)


def _render_kernel_source(parsed, cw, mode):
    """Emit the kernel-mode source body: policy ranges, policy definitions,
    the op table, multicast groups and the family struct."""
    print_kernel_policy_ranges(parsed, cw)

    if any(struct.request for struct in parsed.pure_nested_structs.values()):
        cw.p('/* Common nested types */')
    for _, struct in sorted(parsed.pure_nested_structs.items()):
        if struct.request:
            print_req_policy(cw, struct)
    cw.nl()

    if parsed.kernel_policy == 'global':
        cw.p(f"/* Global operation policy for {parsed.name} */")

        struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
        print_req_policy(cw, struct)
        cw.nl()

    if parsed.kernel_policy in {'per-op', 'split'}:
        for op in parsed.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    cw.p(f"/* {op.enum_name} - {op_mode} */")
                    ri = RenderInfo(cw, parsed, mode, op, op_mode)
                    print_req_policy(cw, ri.struct['request'], ri=ri)
                    cw.nl()

    print_kernel_op_table(parsed, cw)
    print_kernel_mcgrp_src(parsed, cw)
    print_kernel_family_struct_src(parsed, cw)


def _render_user_header(parsed, cw, mode):
    """Emit the user-mode header body: enum-to-string prototypes, nested
    types and the per-op request/response/notification types."""
    cw.p('/* Enums */')
    put_op_name_fwd(parsed, cw)

    for const in parsed.consts.values():
        if isinstance(const, EnumSet):
            put_enum_to_str_fwd(parsed, cw, const)
    cw.nl()

    cw.p('/* Common nested types */')
    for attr_set, struct in parsed.pure_nested_structs.items():
        ri = RenderInfo(cw, parsed, mode, "", "", attr_set)
        print_type_full(ri, struct)
        if struct.request and struct.in_multi_val:
            free_rsp_nested_prototype(ri)
            cw.nl()

    for op in parsed.ops.values():
        cw.p(f"/* ============== {op.enum_name} ============== */")

        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, "do")
            print_req_type(ri)
            print_req_type_helpers(ri)
            cw.nl()
            print_rsp_type(ri)
            print_rsp_type_helpers(ri)
            cw.nl()
            print_req_prototype(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, 'dump')
            print_req_type(ri)
            print_req_type_helpers(ri)
            if not ri.type_consistent:
                print_rsp_type(ri)
            print_wrapped_type(ri)
            print_dump_prototype(ri)
            cw.nl()

        if op.has_ntf:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, 'notify')
            if not ri.type_consistent:
                raise Exception(f'Only notifications with consistent types supported ({op.name})')
            print_wrapped_type(ri)

    for op in parsed.ntfs.values():
        if 'event' in op:
            ri = RenderInfo(cw, parsed, mode, op, 'event')
            cw.p(f"/* {op.enum_name} - event */")
            print_rsp_type(ri)
            cw.nl()
            print_wrapped_type(ri)
    cw.nl()


def _render_user_source(parsed, cw, mode):
    """Emit the user-mode source body: enum-to-string helpers, policies,
    nested type helpers, per-op request/parse/free code and the family."""
    cw.p('/* Enums */')
    put_op_name(parsed, cw)

    for const in parsed.consts.values():
        if isinstance(const, EnumSet):
            put_enum_to_str(parsed, cw, const)
    cw.nl()

    has_recursive_nests = False
    cw.p('/* Policies */')
    for struct in parsed.pure_nested_structs.values():
        if struct.recursive:
            put_typol_fwd(cw, struct)
            has_recursive_nests = True
    if has_recursive_nests:
        cw.nl()
    for name in parsed.pure_nested_structs:
        struct = Struct(parsed, name)
        put_typol(cw, struct)
    for name in parsed.root_sets:
        struct = Struct(parsed, name)
        put_typol(cw, struct)

    cw.p('/* Common nested types */')
    # Recursive nests can reference helpers defined later in the file, so
    # emit prototypes for all nested helpers first.
    if has_recursive_nests:
        for attr_set, struct in parsed.pure_nested_structs.items():
            ri = RenderInfo(cw, parsed, mode, "", "", attr_set)
            free_rsp_nested_prototype(ri)
            if struct.request:
                put_req_nested_prototype(ri, struct)
            if struct.reply:
                parse_rsp_nested_prototype(ri, struct)
        cw.nl()
    for attr_set, struct in parsed.pure_nested_structs.items():
        ri = RenderInfo(cw, parsed, mode, "", "", attr_set)

        free_rsp_nested(ri, struct)
        if struct.request:
            put_req_nested(ri, struct)
        if struct.reply:
            parse_rsp_nested(ri, struct)

    for op in parsed.ops.values():
        cw.p(f"/* ============== {op.enum_name} ============== */")
        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, "do")
            print_req_free(ri)
            print_rsp_free(ri)
            parse_rsp_msg(ri)
            print_req(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, "dump")
            if not ri.type_consistent:
                parse_rsp_msg(ri, deref=True)
            print_req_free(ri)
            print_dump_type_free(ri)
            print_dump(ri)
            cw.nl()

        if op.has_ntf:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, 'notify')
            if not ri.type_consistent:
                raise Exception(f'Only notifications with consistent types supported ({op.name})')
            print_ntf_type_free(ri)

    for op in parsed.ntfs.values():
        if 'event' in op:
            cw.p(f"/* {op.enum_name} - event */")

            ri = RenderInfo(cw, parsed, mode, op, "do")
            parse_rsp_msg(ri)

            ri = RenderInfo(cw, parsed, mode, op, "event")
            print_ntf_type_free(ri)
    cw.nl()
    render_user_family(parsed, cw, False)


def main():
    """Command-line entry point: parse arguments, load the spec, emit code.

    Exits with status 1 (via sys.exit) if the spec fails to parse as YAML or
    carries a license other than the required dual license.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True,
                        choices=('user', 'kernel', 'uapi'))
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    # --header and --source share one destination whose default is None,
    # so None means neither flag was given.
    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
    except yaml.YAMLError as exc:
        print(exc, file=sys.stderr)
        sys.exit(1)

    required_license = '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)'
    if parsed.license != required_license:
        print('Spec license:', parsed.license, file=sys.stderr)
        print('License must be: ' + required_license, file=sys.stderr)
        sys.exit(1)

    # Resolve the spec's in-tree path before setting up the writer so that a
    # spec outside a kernel tree fails early.
    _, spec_kernel = find_kernel_root(args.spec)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        # Record the extra arguments so the file can be regenerated identically.
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    # Companion header name: the output file name with its last two
    # characters (expected to be ".c") replaced by ".h".
    if args.out_file:
        hdr_file = os.path.basename(args.out_file[:-2]) + ".h"
    else:
        hdr_file = "generated_header_file.h"

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = []
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    if args.mode == 'user':
        headers.append(parsed.uapi_header)
    # dict.fromkeys drops duplicates while keeping first-seen order.
    for one in dict.fromkeys(headers):
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            _render_kernel_header(parsed, cw, args.mode)
        else:
            _render_kernel_source(parsed, cw, args.mode)

    if args.mode == "user":
        if args.header:
            _render_user_header(parsed, cw, args.mode)
        else:
            _render_user_source(parsed, cw, args.mode)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
3119
3120
if __name__ == '__main__':
    # Generate code only when executed as a script, not when imported.
    main()
3123