xref: /linux/tools/net/ynl/ynl-gen-c.py (revision 06ce23ad57c8e378b86ef3f439b2e08bcb5d05eb)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import os
7import yaml
8
9from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
10
11
def c_upper(name):
    """Turn a spec identifier into an upper-case C identifier.

    Dashes become underscores, e.g. ``'foo-bar'`` -> ``'FOO_BAR'``.
    """
    return name.replace('-', '_').upper()
14
15
def c_lower(name):
    """Turn a spec identifier into a lower-case C identifier.

    Dashes become underscores, e.g. ``'Foo-Bar'`` -> ``'foo_bar'``.
    """
    return name.replace('-', '_').lower()
18
19
class BaseNlLib:
    """Produce C snippets that call into the libmnl-based ynl library."""

    def get_family_id(self):
        """Return the C expression that holds the family id."""
        return 'ys->family_id'

    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
        """Return a C ``mnl_cb_run2()`` call expression.

        ``cb`` and ``data`` are spliced in verbatim as the callback and its
        argument.  Dumps pass ``0, 0`` for the sequence number and port id;
        other requests pass ``ys->seq`` and ``ys->portid``.  Continuation
        lines start with a newline, two tabs, ``indent`` extra tabs and a
        single space.
        """
        wrap = '\n\t\t' + '\t' * indent + ' '
        tail = f"{wrap}ynl_cb_array, NLMSG_MIN_TYPE)"
        if is_dump:
            return f"mnl_cb_run2(ys->rx_buf, len, 0, 0, {cb}, {data}," + tail
        return f"mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,{wrap}{cb}, {data}," + tail
31
32
class Type(SpecAttr):
    """Base class for a netlink attribute as rendered by the C generator.

    Each subclass handles one spec attribute type.  This class supplies the
    shared plumbing (presence tracking, policy / type-policy table lines,
    parse-branch and setter scaffolding).  The ``_``-prefixed hooks are
    meant to be overridden; here they either raise or return a neutral
    value.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        if 'len' in attr:
            self.len = attr['len']
        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            # Keep in sync with Struct.render_name, which compares the
            # c_lower()-ed set name with the family name, so that references
            # to the nested struct use the name the struct is rendered under.
            if c_lower(self.nested_attrs) == family.name:
                self.nested_render_name = f"{family.name}"
            else:
                self.nested_render_name = f"{family.name}_{c_lower(self.nested_attrs)}"

        self.c_name = c_lower(self.name)
        # Avoid clashing with C keywords.
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve(): declared here for readability, then deleted so
        # that any use before resolve() raises AttributeError.
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Set ``enum_name``: the set's name prefix plus the attr name, upper-cased."""
        self.enum_name = c_upper(f"{self.attr_set.name_prefix}{self.name}")

    def is_multi_val(self):
        """Return a truthy value if the member holds an array (falsy here)."""
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def presence_type(self):
        """Return how presence is tracked: 'bit', 'len', 'count' or ''."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the ``_present`` struct member declaration for this attr.

        Returns None unless presence_type() equals ``type_filter`` and is
        'bit' or 'len'.  ``space == 'user'`` selects the ``__u32`` spelling.
        """
        presence = self.presence_type()
        if presence != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if presence == 'bit':
            return f"{pfx}u32 {self.c_name}:1;"
        if presence == 'len':
            return f"{pfx}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        """Return the C type of a struct/array member, or None for simple types."""
        return None

    def free_needs_iter(self):
        """Return True if free() emits a loop that uses an ``i`` iterator."""
        return False

    def free(self, ri, var, ref):
        """Emit C that frees heap storage owned by this member, if any.

        Multi-value members and 'len'-tracked members are released with
        free().  ``ref`` is a string prepended to the member name (a path
        through enclosing structs).
        """
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the C parameter declaration(s) used by setters for this attr."""
        member = self._complex_member_type(ri)
        if member:
            return [member + ' *' + self.c_name]
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for this attr.

        Multi-value members get an ``n_<name>`` count followed by a pointer.
        """
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Return the initializer of this attr's kernel ``nla_policy`` entry."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attr's line of the kernel policy table."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the fields of the user-space type-policy entry."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attr's line of the user-space type-policy table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        """Emit ``line`` as a statement, guarded by the presence check if any."""
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a guarded ``mnl_attr_put_<put_type>()`` of the member value."""
        line = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return ``(lines, init_lines, local_vars)`` for parsing this attr.

        Each element may be None; ``lines`` and ``init_lines`` may also be a
        single string.
        """
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit one ``if`` / ``else if`` branch of the attribute parse loop.

        Single-value attrs are checked with ynl_attr_validate() and, for
        'bit' presence, marked present before the type-specific lines run.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (mnl_attr_get_type(attr) == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return MNL_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()

    def _setter_lines(self, ri, member, presence):
        """Return the C body lines that store the setter argument(s)."""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a ``static inline`` setter function for this member.

        ``ref`` lists the enclosing nested member names, outermost first.
        When this attr uses 'bit' presence, the presence bit of every level
        on that path is set.  ``space`` is not used by this implementation.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        for i in range(0, len(ref)):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{ref[i]}"
            if self.presence_type() == 'bit':
                code.append(presence + ' = 1;')
        # `presence` now names the innermost (this attr's) presence field.
        code += self._setter_lines(ri, member, presence)

        ri.cw.write_func('static inline void',
                         f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}",
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
194
195
class TypeUnused(Type):
    """Attribute slot the spec marks as 'unused'.

    It has no presence tracking and no kernel policy line.  In the
    user-space type policy it is marked YNL_PT_REJECT.
    """

    def attr_policy(self, cw):
        # Intentionally contributes nothing to the kernel policy table.
        return None

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def presence_type(self):
        return ''
205
206
class TypePad(Type):
    """Padding attribute ('pad' in the spec).

    It has no presence tracking and no kernel policy line.  In the
    user-space type policy it is marked YNL_PT_REJECT.
    """

    def attr_policy(self, cw):
        # Intentionally contributes nothing to the kernel policy table.
        return None

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def presence_type(self):
        return ''
216
217
class TypeScalar(Type):
    """Fixed-width integer attribute (u8/u16/u32/u64/s32/s64)."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Byte order is only surfaced as a comment next to the C member.
        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Added by resolve(): declared for readability, then deleted so any
        # use before resolve() raises AttributeError.
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Decide whether the value is a bit mask and pick its C type name."""
        self.resolve_up(super())

        if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']:
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if 'enum' in self.attr and not self.is_bitfield:
            self.type_name = f"enum {self.family.name}_{c_lower(self.attr['enum'])}"
        else:
            self.type_name = '__' + self.type

    def _mnl_type(self):
        """Return the libmnl accessor suffix for this type (signed -> unsigned)."""
        t = self.type
        # mnl does not have a helper for signed types
        if t[0] == 's':
            t = 'u' + t[1:]
        return t

    def _attr_policy(self, policy):
        """Return the kernel policy initializer, including value checks.

        Masks and enum ranges are derived from the actual entry values, so
        enums or flag sets that use ``value-start`` or explicit values are
        validated correctly.
        """
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
            else:
                enum = self.family.consts[self.checks['flags-mask']]
            # Sum of 1 << value over the entries; for entries numbered
            # 0..N-1 this equals (1 << N) - 1.
            mask = enum.get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.checks['min']})"
        elif 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            values = [entry.value for entry in enum.entries.values()]
            # The defaults reproduce the historical "MAX(-1)" for an empty enum.
            low = min(values, default=0)
            high = max(values, default=-1)
            if low == 0:
                return f"NLA_POLICY_MAX({policy}, {high})"
            return f"NLA_POLICY_RANGE({policy}, {low}, {high})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed types share the unsigned type-policy entry of equal width.
        return f'.type = YNL_PT_U{self.type[1:]}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self._mnl_type())

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = mnl_attr_get_{self._mnl_type()}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
286
287
class TypeFlag(Type):
    """Zero-length attribute: only its presence is recorded and sent."""

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        # The setter takes no value; it only marks the flag present.
        return []

    def _setter_lines(self, ri, member, presence):
        return []

    def _attr_get(self, ri, var):
        # Nothing to copy out; presence is handled by the common parse code.
        return [], None, None

    def attr_put(self, ri, var):
        put = f"mnl_attr_put(nlh, {self.enum_name}, 0, NULL)"
        self._attr_put_line(ri, var, put)
303
304
class TypeString(Type):
    """String attribute, stored as a heap copy with its length tracked."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def attr_policy(self, cw):
        """Emit the kernel policy line; NLA_STRING if 'unterminated-ok' is set."""
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'
        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_policy(self, policy):
        fields = ['.type = ' + policy]
        if 'max-len' in self.checks:
            fields.append('.len = ' + str(self.checks['max-len']))
        return '{ ' + ', '.join(fields) + ', }'

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'strz')

    def _attr_get(self, ri, var):
        # Copy at most the payload length and terminate the copy ourselves.
        dst = f"{var}->{self.c_name}"
        get_lines = [f"{var}->_present.{self.c_name}_len = len;",
                     f"{dst} = malloc(len + 1);",
                     f"memcpy({dst}, mnl_attr_get_str(attr), len);",
                     f"{dst}[len] = 0;"]
        init_lines = ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"free({member});",
                f"{length} = strlen({self.c_name});",
                f"{member} = malloc({length} + 1);",
                f'memcpy({member}, {self.c_name}, {length});',
                f'{member}[{length}] = 0;']
352
353
class TypeBinary(Type):
    """Opaque byte-blob attribute, stored as a heap copy plus its length."""

    def arg_member(self, ri):
        # Setters take the buffer and its size as `len`.
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        # Trailing space matches every other type's typol fragment.
        return '.type = YNL_PT_BINARY, '

    def _attr_policy(self, policy):
        """Return the kernel policy initializer.

        Only a lone 'min-len' check is supported; it is emitted as ``.len``
        without an explicit ``.type``.
        """
        mem = '{ '
        if len(self.checks) == 1 and 'min-len' in self.checks:
            mem += '.len = ' + str(self.checks['min-len'])
        elif len(self.checks) == 0:
            mem += '.type = NLA_BINARY'
        else:
            raise Exception('One or more of binary type checks not implemented, yet')
        mem += ', }'
        return mem

    def attr_put(self, ri, var):
        self._attr_put_line(ri, var, f"mnl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->_present.{self.c_name}_len, {var}->{self.c_name})")

    def _attr_get(self, ri, var):
        len_mem = var + '->_present.' + self.c_name + '_len'
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, mnl_attr_get_payload(attr), len);"], \
               ['len = mnl_attr_get_payload_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        # Record the caller-supplied length first; the allocation, the copy
        # and later attr_put() all size the buffer from the _len field.
        return [f"free({member});",
                f"{presence}_len = len;",
                f"{member} = malloc({presence}_len);",
                f'memcpy({member}, {self.c_name}, {presence}_len);']
394
395
class TypeNest(Type):
    """Attribute nesting another attribute set, embedded as a sub-struct."""

    def _complex_member_type(self, ri):
        return f"struct {self.nested_render_name}"

    def _attr_policy(self, policy):
        return f"NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)"

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def free(self, ri, var, ref):
        # Delegate to the nested struct's own free helper.
        ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name});')

    def attr_put(self, ri, var):
        call = (f"{self.nested_render_name}_put(nlh, "
                f"{self.enum_name}, &{var}->{self.c_name})")
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        name = self.nested_render_name
        init_lines = [f"parg.rsp_policy = &{name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [f"{name}_parse(&parg, attr);"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit setters for every member of the nested set, under this member."""
        path = [*(ref or []), self.c_name]
        nested_struct = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested_struct.member_list():
            member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
424
425
class TypeMultiAttr(Type):
    """Attribute that may repeat, stored as a pointer plus an ``n_`` count."""

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        """Return the element type: a nested struct or a (kernel/user) scalar."""
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return f"struct {self.nested_render_name}"
        elif self.attr['type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        else:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        # Only nested elements are freed in a loop over `i`.
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def free(self, ri, var, ref):
        """Emit C releasing the collected elements and the array holding them.

        Nested elements are freed one by one first.  The array itself is then
        freed for every sub-type, as Type.free does for multi-value members.
        """
        member = f"{var}->{ref}{self.c_name}"
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{member}[i]);')
        ri.cw.p(f"free({member});")

    def _attr_typol(self):
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
        elif self.attr['type'] in scalars:
            return f".type = YNL_PT_U{self.attr['type'][1:]}, "
        else:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def _attr_get(self, ri, var):
        # The parse branch only counts occurrences.
        return f'{var}->n_{self.c_name}++;', None, None
460
461
class TypeArrayNest(Type):
    """Nest whose sub-attributes are array elements, counted via ``n_<name>``."""

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        sub_type = self.attr.get('sub-type', 'nest')
        if sub_type == 'nest':
            return f"struct {self.nested_render_name}"
        if sub_type in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + sub_type
        raise Exception(f"Sub-type {sub_type} not supported yet")

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remember the nest and count each nested attribute as one element.
        counter = f'\t{var}->n_{self.c_name}++;'
        get_lines = [f'attr_{self.c_name} = attr;',
                     'mnl_attr_for_each_nested(attr2, attr)',
                     counter]
        return get_lines, None, ['const struct nlattr *attr2;']
487
488
class TypeNestTypeValue(Type):
    """Nest whose attribute *types* carry values ('type-value' levels).

    For each listed level, the parse code steps into the payload and reads
    the attribute type into a local.  Those locals are then passed to the
    nested parser.
    """

    def _complex_member_type(self, ri):
        return f"struct {self.nested_render_name}"

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        local_vars = []
        get_lines = []
        cursor = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append('const struct nlattr *attr_' + ', *attr_'.join(levels) + ';')
            local_vars.append('__u32 ' + ', '.join(levels) + ';')
            for lvl in levels:
                get_lines.append(f'attr_{lvl} = mnl_attr_get_payload({cursor});')
                get_lines.append(f'{lvl} = mnl_attr_get_type(attr_{lvl});')
                cursor = f'attr_{lvl}'
            extra_args = ', ' + ', '.join(levels)

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
517
518
class Struct:
    """C struct rendered for an attribute set, or a selection from it.

    ``type_list`` names the attributes to include, in order.  When it is
    falsy, every attribute of the set is used.  ``nested`` is True only when
    no ``type_list`` was given at all.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        lower_space = c_lower(space_name)
        if family.name == lower_space:
            self.render_name = f"{family.name}"
        else:
            self.render_name = f"{family.name}_{lower_space}"
        self.struct_name = 'struct ' + self.render_name
        self.ptr_name = self.struct_name + ' *'

        self.request = False
        self.reply = False

        # NOTE(review): an empty ``type_list`` is falsy, so it selects the
        # whole set rather than no attributes -- confirm this is intended.
        names = type_list if type_list else list(self.attr_set)
        self.attr_list = [(name, self.attr_set[name]) for name in names]

        # Map names to attrs and remember the attr with the highest value
        # (the last one wins on ties).
        self.attrs = dict()
        self.attr_max_val = None
        highest = 0
        for name, attr in self.attr_list:
            if attr.value >= highest:
                highest = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        return iter(self.attrs)

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the ``(name, attr)`` pairs in rendering order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record inherited members; they must match those given at creation."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = list(map(c_lower, sorted(self._inherited)))
569
570
class EnumEntry(SpecEnumEntry):
    """Enum entry extended with its C name and a value-change marker."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # True when the value does not follow from the previous entry (or is
        # not 0 for the first one), and always for flags.
        expected = prev.value + 1 if prev else 0
        self.value_change = self.value != expected or self.enum_set['type'] == 'flags'

        # Added by resolve: declared, then deleted so early use raises.
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
589
590
class EnumSet(SpecEnumSet):
    """Enum/flags definition with the C names used by the generator."""

    def __init__(self, family, yaml):
        # Assigned before the base constructor runs -- presumably because
        # entries created there may need them; confirm before reordering.
        name = yaml['name']
        self.render_name = c_lower(family.name + '-' + name)
        self.enum_name = f'enum {self.render_name}'

        default_pfx = f"{family.name}-{name}-"
        self.value_pfx = yaml.get('name-prefix', default_pfx)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)
602
603
class AttrSet(SpecAttrSet):
    """Attribute set with C enum naming and a factory for its attr types."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # A subset reuses the enum naming of the set it subsets.
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.name + '-a-'
            else:
                pfx = f"{family.name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))

        # Added by resolve: declared, then deleted so early use raises.
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        c_name = c_lower(self.name)
        if c_name in _C_KW:
            c_name += '_'
        # The set named like the family gets an empty C name.
        self.c_name = '' if c_name == self.family.c_name else c_name

    def new_attr(self, elem, value):
        """Instantiate the Type subclass matching the spec entry ``elem``."""
        if 'multi-attr' in elem and elem['multi-attr']:
            return TypeMultiAttr(self.family, self, elem, value)

        kind = elem['type']
        if kind in scalars:
            cls = TypeScalar
        else:
            cls = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'binary': TypeBinary,
                'nest': TypeNest,
                'array-nest': TypeArrayNest,
                'nest-type-value': TypeNestTypeValue,
            }.get(kind)
            if cls is None:
                raise Exception(f"No typed class for type {elem['type']}")
        return cls(self.family, self, elem, value)
655
656
class Operation(SpecOperation):
    """Operation spec with C naming and notification bookkeeping."""

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        # Request and reply must use the same command value.
        if req_value != rsp_value:
            raise Exception("Directional messages not supported by codegen")

        self.render_name = f"{family.name}_{c_lower(self.name)}"

        def has_request(mode):
            return mode in yaml and 'request' in yaml[mode]

        # True when both 'do' and 'dump' define a request.
        self.dual_policy = has_request('do') and has_request('dump')

        # Added by resolve: declared, then deleted so early use raises.
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def add_notification(self, op):
        """Register ``op`` as a notification that reuses this op's 'do' reply."""
        if 'notify' not in self.yaml:
            notify = self.yaml['notify'] = dict()
            notify['reply'] = self.yaml['do']['reply']
            notify['cmds'] = []
        self.yaml['notify']['cmds'].append(op)
687
688
class Family(SpecFamily):
    """Family spec extended with the derived state the C generator needs."""

    def __init__(self, file_name):
        # Added by resolve: declared for readability, then deleted so that
        # any use before resolve() raises AttributeError.
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name)

        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.name}.h"

    def resolve(self):
        """Compute C naming, hooks, and the root / nested struct bookkeeping."""
        self.resolve_up(super())

        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
            raise Exception("Codegen only supported for genetlink")

        self.c_name = c_lower(self.name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] -> {'set': names seen, 'list': names in order}
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct (sets reachable only by nesting)
        self.pure_nested_structs = dict()
        self.all_notify = dict()

        self._mock_up_events()

        self._dictify()
        self._load_root_sets()
        self._load_nested_sets()
        self._load_all_notify()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _dictify(self):
        """Attach every message with a 'notify' key to the op it names."""
        ntf = []
        for msg in self.msgs.values():
            if 'notify' in msg:
                ntf.append(msg)
        for n in ntf:
            self.ops[n['notify']].add_notification(n)

    def _load_root_sets(self):
        """Collect, per attribute set, the attrs used in requests and replies."""
        for op_name, op in self.ops.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _load_nested_sets(self):
        """Create one Struct per set nested from a root set.

        A set may be nested from several attributes.  Its Struct is created
        once, so request/reply usage accumulates across all references and
        set_inherited() can detect references that inherit different members.
        """
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested in self.root_sets:
                    raise Exception(f"Using attr set '{nested}' both as root and nested not supported")

                inherit = set()
                if 'type-value' in spec:
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'array-nest':
                    inherit.add('idx')

                struct = self.pure_nested_structs.get(nested)
                if struct is None:
                    struct = Struct(self, nested, inherited=inherit)
                    self.pure_nested_structs[nested] = struct
                if attr in rs_members['request']:
                    struct.request = True
                if attr in rs_members['reply']:
                    struct.reply = True
                struct.set_inherited(inherit)

    def _load_all_notify(self):
        for op_name, op in self.ops.items():
            if not op:
                continue

            if 'notify' in op:
                self.all_notify[op_name] = op['notify']['cmds']

    def _load_global_policy(self):
        """Build the single policy shared by all ops ('kernel-policy: global').

        Includes every attr of the common set that appears in any 'do' or
        'dump' request, in attribute-set order.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    # 'request' is a dict holding an 'attributes' list (see
                    # _load_root_sets); updating with the dict itself would
                    # only add the key 'attributes'.
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Record each distinct pre/post hook name per mode, in first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
874
875
class RenderInfo:
    """Everything needed to render C code for one op in one mode.

    Attributes:
        family, ku_space, op, op_name, op_mode, cw: as passed in.
        nl: the code writer's netlink library helper (cw.nlib).
        type_consistent: True when the reply type is shared between modes.
            That is the case when rendering a non-'do' mode of an op whose
            'do' and 'dump' replies are identical, or when rendering an
            'event'.
        attr_set: attribute set name; the attr_set argument if given, else
            op['attribute-set'].
        type_name: C-lowercased op_name, or the attr_set argument when op
            is empty.
        struct: maps 'request'/'reply' to a Struct for each direction the
            op defines in op_mode.  Events always get a 'reply' Struct.
    """
    def __init__(self, cw, family, ku_space, op, op_name, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op = op
        self.op_name = op_name
        self.op_mode = op_mode

        # 'do' and 'dump' response parsing is identical.  Check that the
        # dump really has a reply before comparing, so a reply-less dump is
        # "not consistent" rather than a KeyError.
        if op_mode != 'do' and 'dump' in op and 'do' in op and \
           'reply' in op['do'] and 'reply' in op['dump'] and \
           op["do"]["reply"] == op["dump"]["reply"]:
            self.type_consistent = True
        else:
            self.type_consistent = op_mode == 'event'

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        if op:
            self.type_name = c_lower(op_name)
        else:
            self.type_name = c_lower(attr_set)

        self.cw = cw

        self.struct = dict()
        for op_dir in ['request', 'reply']:
            if op and op_dir in op[op_mode]:
                self.struct[op_dir] = Struct(family, self.attr_set,
                                             type_list=op[op_mode][op_dir]['attributes'])
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
910
911
class CodeWriter:
    """Tab-indenting writer for generated C code.

    Tracks the current indentation level and a pending blank line.  Offers
    helpers for braces/blocks, function prototypes and definitions, doc
    comment lines, #define lists and designated struct initializers.
    """

    def __init__(self, nlib, out_file):
        self.nlib = nlib

        self._nl = False            # write a blank line before the next line
        self._silent_block = False  # last line was a brace-less if/while/for (...)
        self._ind = 0               # current indentation, in tabs
        self._out = out_file

    @classmethod
    def _is_cond(cls, line):
        return line.startswith('if') or line.startswith('while') or line.startswith('for')

    def p(self, line, add_ind=0):
        """Write one line at the current indentation (plus add_ind).

        A line ending in ':' (a label or case) is outdented one level.  The
        line after a brace-less 'if (...)' / 'while (...)' / 'for (...)' is
        indented one extra level.  An empty line is written as a bare
        newline.
        """
        if self._nl:
            self._out.write('\n')
            self._nl = False
        ind = self._ind
        # endswith() rather than line[-1] so an empty line does not raise.
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if add_ind:
            ind += add_ind
        if line:
            self._out.write('\t' * ind + line + '\n')
        else:
            self._out.write('\n')

    def nl(self):
        """Request a blank line before the next line (repeats coalesce)."""
        self._nl = True

    def block_start(self, line=''):
        """Write 'line {' (or just '{') and indent one level."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Outdent one level and write '}' followed by the optional line.

        ';' and ',' are attached directly to the brace; anything else is
        separated by one space.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write doc as ' *'-prefixed comment lines, wrapped before column 79.

        Continuation lines get two extra spaces when indent is True.  The
        first word never triggers a wrap, so an over-long first word does
        not produce an empty ' *' line.
        """
        words = doc.split()
        line = ' *'
        for i, word in enumerate(words):
            if i and len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype, wrapping arguments past ~80 columns.

        args defaults to ['void'].  doc, if given, goes in a comment above
        the prototype.  suffix (e.g. ';') is appended after the ')'.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        # Long return types go on their own line; arguments are aligned
        # under the opening parenthesis using tabs, then spaces.
        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        Accepts a single declaration string or a list.  The caller's list is
        left unmodified.
        """
        if not local_vars:
            return

        if type(local_vars) is str:
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a complete function definition.

        The local variable declarations are placed inside the opening
        brace, ahead of the body lines, as C requires.
        """
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write '#define NAME<tabs>VALUE' lines with tab-aligned values.

        Each define is a (name, value) pair.  Int values are written as-is,
        str values are quoted, and other types get no value.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write tab-aligned designated initializers '.name = value,'.

        members is a list of (name, value) string pairs.  An empty list
        writes nothing.
        """
        if not members:
            return
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + one[1] + ','
            self.p(line)
1053
1054
# Netlink attribute types that are rendered as plain C integer members.
scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

# Type-name suffix for each message direction ('' = no direction).
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': ''
}

# Suffix of the wrapper type for each op mode ('do' and 'event' need none).
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# C keywords.  An attribute whose C name collides with one of these gets a
# trailing '_' (see Type.__init__), so it can still be used as a struct member.
_C_KW = {
    'auto', 'bool', 'break', 'case', 'char', 'const', 'continue',
    'default', 'do', 'double', 'else', 'enum', 'extern', 'float', 'for',
    'goto', 'if', 'inline', 'int', 'long', 'register', 'restrict',
    'return', 'short', 'signed', 'sizeof', 'static', 'struct', 'switch',
    'typedef', 'union', 'unsigned', 'void', 'volatile', 'while',
}
1073
1074
def rdir(direction):
    """Return the opposite message direction.

    'request' and 'reply' swap; any other value (e.g. '') is returned as is.
    """
    opposite = {'reply': 'request', 'request': 'reply'}
    return opposite.get(direction, direction)
1081
1082
def op_prefix(ri, direction, deref=False):
    """Return the C identifier prefix '<family>_<type_name><suffix>'.

    'do' (or no mode) uses the plain direction suffix.  Other modes use
    '_req_dump' for requests.  For replies:
      - shared reply type (ri.type_consistent): the direction suffix when
        deref is set, otherwise the mode's wrapper suffix;
      - otherwise: '_rsp_dump' when deref is set, '_rsp_list' when not.
    """
    suffix = f"_{ri.type_name}"

    if not ri.op_mode or ri.op_mode == 'do':
        suffix += direction_to_suffix[direction]
    elif direction == 'request':
        suffix += '_req_dump'
    elif ri.type_consistent:
        suffix += direction_to_suffix[direction] if deref else op_mode_to_wrapper[ri.op_mode]
    else:
        suffix += '_rsp_dump' if deref else '_rsp_list'

    return f"{ri.family['name']}{suffix}"
1102
1103
def type_name(ri, direction, deref=False):
    """Return the C struct type 'struct <op_prefix(...)>'."""
    prefix = op_prefix(ri, direction, deref=deref)
    return "struct " + prefix
1106
1107
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the user-space request (or dump) function.

    The function is named after the op, with a '_dump' suffix for dumps.
    It takes the socket plus, if the mode has a request, a pointer to the
    request type named after the direction suffix (e.g. 'req').  It returns
    a pointer to the reply type when the mode has a reply, otherwise int.
    """
    name = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')
    mode_spec = ri.op[ri.op_mode]

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{arg_name}")

    ret = f"{type_name(ri, rdir(direction))} *" if 'reply' in mode_spec else 'int'

    ri.cw.write_func_prot(ret, name, args, doc=doc, suffix=';' if terminate else '')
1124
1125
def print_req_prototype(ri):
    """Emit the prototype of an op's request function.

    If the op has a 'doc' entry it is written as a comment above the
    prototype; ops without one get no comment.
    """
    doc = ri.op['doc'] if 'doc' in ri.op else None
    print_prototype(ri, "request", doc=doc)
1128
1129
def print_dump_prototype(ri):
    """Emit the prototype of an op's dump function (no doc comment)."""
    print_prototype(ri, direction="request")
1132
1133
def put_typol_fwd(cw, struct):
    """Emit an extern declaration of the struct's ynl_policy_nest."""
    nest_name = struct.render_name + '_nest'
    cw.p('extern struct ynl_policy_nest ' + nest_name + ';')
1136
1137
def put_typol(cw, struct):
    """Emit a struct's user-space policy table and its nest descriptor.

    First '<name>_policy[<MAX> + 1]', filled by each member's
    attr_typol().  Then '<name>_nest', which points max_attr/table at it.
    """
    max_name = struct.attr_set.max_name
    base = struct.render_name

    cw.block_start(line=f'struct ynl_policy_attr {base}_policy[{max_name} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'struct ynl_policy_nest {base}_nest =')
    for init in (f'.max_attr = {max_name},', f'.table = {base}_policy,'):
        cw.p(init)
    cw.block_end(line=';')
    cw.nl()
1153
1154
def put_req_nested(ri, struct):
    """Emit '<name>_put()', which writes 'obj' into a nested attribute.

    The generated function opens a nest of type attr_type, lets each
    member emit its own put code against 'obj', closes the nest and
    returns 0.
    """
    cw = ri.cw
    cw.write_func_prot('int', f'{struct.render_name}_put',
                       ['struct nlmsghdr *nlh',
                        'unsigned int attr_type',
                        f'{struct.ptr_name}obj'])
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = mnl_attr_nest_start(nlh, attr_type);")
    for _, member in struct.member_list():
        member.attr_put(ri, "obj")
    cw.p("mnl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1175
1176
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a parse function for a message or nested struct.

    The caller has already written the prototype.  This emits:
      - the opening brace and the local variables (local_vars plus any
        added here);
      - init_lines, then 'dst-><x> = <x>;' for each inherited member;
      - a loop over the attributes that calls each member's attr_get();
      - second passes that allocate and fill 'array-nest' and 'multi-attr'
        members;
      - the return statement, which is 0 for nests and MNL_CB_OK for
        messages.

    Args:
        ri: RenderInfo providing the code writer (ri.cw).
        struct: Struct to parse.  struct.nested selects iterating a nest
            attribute ('nested') or a whole genetlink message ('nlh').
        init_lines: C statements emitted after the locals; appended to here.
        local_vars: C declarations; appended to here.

    Raises:
        Exception: for a 'multi-attr' member that is neither nested nor of
            a scalar type.
    """
    # Nests walk the attributes inside 'nested'; whole messages walk the
    # attributes after the genetlink header.
    if struct.nested:
        iter_line = "mnl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr))"

    # Members that need a second pass after the main attribute loop:
    # array-nests (whose nest attribute is kept in attr_<name>) and
    # multi-attr members.
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'array-nest':
            local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
            array_nests.add(arg)
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    # A parse arg for calling sub-parsers is only needed if some member nests.
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    # Inherited members arrive as function parameters of the same name.
    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    ri.cw.nl()
    ri.cw.block_start(line=iter_line)

    # Only the first member gets first=True (presumably so attr_get() can
    # start an if / else-if chain; confirm in Type.attr_get).
    first = True
    for _, arg in struct.member_list():
        arg.attr_get(ri, 'dst', first=first)
        first = False

    ri.cw.block_end()
    ri.cw.nl()

    # Array-nest second pass: allocate dst->n_<name> elements, then parse
    # each entry of the saved nest attribute into them.
    # NOTE(review): the generated calloc() result is not NULL-checked.
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (dst->n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(dst->n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p('i = 0;')
        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
        # The extra argument is the entry's attribute type, which feeds the
        # 'idx' member that array-nest structs inherit.
        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))")
        ri.cw.p('return MNL_CB_ERROR;')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Multi-attr second pass: allocate, then walk the attributes again and
    # store each occurrence of this attribute type in order.
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (dst->n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(dst->n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (mnl_attr_get_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return MNL_CB_ERROR;')
        elif aspec['type'] in scalars:
            # Signed types are read with the unsigned getter of the same width.
            t = aspec['type']
            if t[0] == 's':
                t = 'u' + t[1:]
            ri.cw.p(f"dst->{aspec.c_name}[i] = mnl_attr_get_{t}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return MNL_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1268
1269
def parse_rsp_nested(ri, struct):
    """Emit '<name>_parse()' for a nested struct.

    Besides the parse arg and the nest attribute, the generated function
    takes one __u32 parameter per inherited member.
    """
    func_args = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    func_args += ['__u32 ' + member for member in struct.inherited]

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args)

    _multi_parse(ri, struct, [],
                 ['const struct nlattr *attr;',
                  f'{struct.ptr_name}dst = yarg->data;'])
1283
1284
def parse_rsp_msg(ri, deref=False):
    """Emit the mnl callback that parses a reply (or event) message.

    Does nothing when the op mode has no reply and is not an event.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    reply_type = type_name(ri, "reply", deref=deref)
    fn_name = f'{op_prefix(ri, "reply", deref=deref)}_parse'

    ri.cw.write_func_prot('int', fn_name, ['const struct nlmsghdr *nlh', 'void *data'])

    _multi_parse(ri, ri.struct["reply"], ['dst = yarg->data;'],
                 [f'{reply_type} *dst;',
                  'struct ynl_parse_arg *yarg = data;',
                  'const struct nlattr *attr;'])
1300
1301
def print_req(ri):
    """Emit the user-space function that sends a request and reads the answer.

    The generated C code:
      - builds the message from the 'req' struct and sends it;
      - receives one buffer;
      - if the op mode has a reply, parses it into a calloc()ed 'rsp';
      - calls ynl_recv_ack().
    It returns 'rsp' (or 0 when there is no reply).  On error it returns
    NULL (or -1), freeing 'rsp' first if it was allocated.
    """
    ret_ok = '0'
    ret_err = '-1'
    direction = "request"
    local_vars = ['struct nlmsghdr *nlh;',
                  'int len, err;']

    # Ops with a reply return a pointer to it instead of an int status.
    if 'reply' in ri.op[ri.op_mode]:
        ret_ok = 'rsp'
        ret_err = 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;',
                       'struct ynl_parse_arg yarg = { .ys = ys, };']

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()
    # Each request member emits its own attribute-put code against 'req'.
    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = mnl_socket_sendto(ys->sock, nlh, nlh->nlmsg_len);')
    ri.cw.p('if (err < 0)')
    ri.cw.p(f"return {ret_err};")
    ri.cw.nl()
    ri.cw.p('len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);')
    ri.cw.p('if (len < 0)')
    ri.cw.p(f"return {ret_err};")
    ri.cw.nl()

    # NOTE(review): the generated calloc() result is not NULL-checked.
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yarg.data = rsp;')
        ri.cw.nl()
        ri.cw.p(f"err = {ri.nl.parse_cb_run(op_prefix(ri, 'reply') + '_parse', '&yarg', False)};")
        ri.cw.p('if (err < 0)')
        ri.cw.p('goto err_free;')
        ri.cw.nl()

    ri.cw.p('err = ynl_recv_ack(ys, err);')
    ri.cw.p('if (err)')
    ri.cw.p('goto err_free;')
    ri.cw.nl()
    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()
    ri.cw.p('err_free:')

    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
    ri.cw.p(f"return {ret_err};")
    ri.cw.block_end()
1359
1360
def print_dump(ri):
    """Emit the user-space dump function.

    The generated C code:
      - sets up a ynl_dump_state (reply size, parse callback, reply policy);
      - sends the dump request, including request attributes if the mode
        has any;
      - repeatedly receives buffers and runs ynl_dump_trampoline over them
        while the callback run returns > 0.
    It returns yds.first.  On error it frees that list and returns NULL.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int len, err;']

    for var in local_vars:
        ri.cw.p(f'{var}')
    ri.cw.nl()

    ri.cw.p('yds.ys = ys;')
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    ri.cw.p(f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()
    ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    # Dump requests may carry filter attributes; emit them only if present.
    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = mnl_socket_sendto(ys->sock, nlh, nlh->nlmsg_len);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('return NULL;')
    ri.cw.nl()

    # Keep receiving until the callback run stops returning > 0.
    ri.cw.block_start(line='do')
    ri.cw.p('len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);')
    ri.cw.p('if (len < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()
    ri.cw.p(f"err = {ri.nl.parse_cb_run('ynl_dump_trampoline', '&yds', False, indent=2)};")
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.block_end(line='while (err > 0);')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
1409
1410
def call_free(ri, direction, var):
    """Return the C statement that frees var with the direction's free helper."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
1413
1414
def free_arg_name(direction):
    """Name of a free function's pointer argument.

    'req' or 'rsp' for a direction; 'obj' for the empty direction.
    """
    return direction_to_suffix[direction][1:] if direction else 'obj'
1419
1420
def print_free_prototype(ri, direction, suffix=';'):
    """Emit 'void <prefix>_free(struct <prefix> *<arg>)' followed by suffix."""
    prefix = op_prefix(ri, direction)
    param = f"struct {prefix} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', prefix + '_free', [param], suffix=suffix)
1425
1426
def _print_type(ri, direction, struct):
    """Emit the C struct definition for one direction of an op.

    The struct is named '<family>_<type_name><dir suffix>', plus '_dump' for
    dumps.  Its layout is:
      - an anonymous '_present' struct holding each attribute's 'len' and
        'bit' presence members (omitted when there are none);
      - one __u32 per inherited member;
      - each attribute's own member.
    """
    name = f"{ri.family['name']}_{ri.type_name}{direction_to_suffix[direction]}"
    if ri.op_mode == 'dump':
        name += '_dump'
    ri.cw.block_start(line=f"struct {name}")

    presence = []
    for _, attr in struct.member_list():
        for type_filter in ('len', 'bit'):
            member_line = attr.presence_member(ri.ku_space, type_filter)
            if member_line:
                presence.append(member_line)

    if presence:
        ri.cw.block_start(line="struct")
        for member_line in presence:
            ri.cw.p(member_line)
        ri.cw.block_end(line='_present;')
        ri.cw.nl()

    for member in struct.inherited:
        ri.cw.p(f"__u32 {member};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    ri.cw.block_end(line=';')
    ri.cw.nl()
1456
1457
def print_type(ri, direction):
    """Emit the struct for the given direction from ri.struct[direction]."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
1460
1461
def print_type_full(ri, struct):
    """Emit struct under a type name with no direction suffix."""
    _print_type(ri, direction="", struct=struct)
1464
1465
def print_type_helpers(ri, direction, deref=False):
    """Emit the free() prototype and, for user-space requests, member setters."""
    print_free_prototype(ri, direction)

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1473
1474
def print_req_type_helpers(ri):
    """Emit the request free() prototype (and, in user space, its setters)."""
    print_type_helpers(ri, direction="request")
1477
1478
def print_rsp_type_helpers(ri):
    """Emit the reply helpers, but only when the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1483
1484
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of '<op>_rsp_parse' / '<op>_req_parse'.

    The function takes an attribute table (tb) and a struct pointer.  The
    pointer is always named 'req', whatever the direction.
    """
    kind = "_rsp" if direction == "reply" else "_req"
    base = f"{ri.op.render_name}{kind}"
    ri.cw.write_func_prot('void', f"{base}_parse",
                          ['const struct nlattr **tb', f"struct {base} *req"],
                          suffix=';' if terminate else '')
1493
1494
def print_req_type(ri):
    """Emit the op's request struct."""
    print_type(ri, direction="request")
1497
1498
def print_rsp_type(ri):
    """Emit the reply struct for events and for do/dump modes that have a reply."""
    mode = ri.op_mode
    if mode == 'event' or (mode in ('do', 'dump') and 'reply' in ri.op[mode]):
        print_type(ri, 'reply')
1507
1508
def print_wrapped_type(ri):
    """Emit the wrapper struct around a reply, plus its free() prototype.

    Dumps get a 'next' pointer (the list that the dump free() walks).
    Notifications and events get family, cmd and a free callback.  All
    wrappers embed the reply as an 8-byte-aligned 'obj'.
    """
    wrapper = type_name(ri, 'reply')
    cw = ri.cw

    cw.block_start(line=wrapper)
    if ri.op_mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        cw.p('__u16 family;')
        cw.p('__u8 cmd;')
        cw.p(f"void (*free)({wrapper} *ntf);")
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__ ((aligned (8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
1522
1523
def _free_type_members_iter(ri, struct):
    """Declare the loop counter 'i' if any member's free code iterates."""
    if any(attr.free_needs_iter() for _, attr in struct.member_list()):
        ri.cw.p('unsigned int i;')
        ri.cw.nl()
1530
1531
def _free_type_members(ri, var, struct, ref=''):
    """Emit free code for each member of struct, accessed as var->ref<member>."""
    members = [member for _, member in struct.member_list()]
    for member in members:
        member.free(ri, var, ref)
1535
1536
def _free_type(ri, direction, struct):
    """Emit the '<prefix>_free()' definition for struct.

    Releases every member.  The object itself is free()d only when a
    direction is given.  With direction '' only the members are released,
    presumably because such objects are embedded in a parent (confirm with
    callers).
    """
    arg = free_arg_name(direction)
    cw = ri.cw

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
1548
1549
def free_rsp_nested(ri, struct):
    """Emit the free() definition for a struct with no direction."""
    _free_type(ri, direction="", struct=struct)
1552
1553
def print_rsp_free(ri):
    """Emit the reply free() definition when the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
1558
1559
def print_dump_type_free(ri):
    """Emit the free() for a dump's linked list of reply wrappers.

    The generated loop follows 'next'.  For each node it releases the
    members under 'obj.' and then the node itself.
    """
    node_type = type_name(ri, 'reply')
    cw = ri.cw

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(node_type + " *next = rsp;")
    cw.nl()
    cw.block_start(line='while (next)')
    reply = ri.struct['reply']
    _free_type_members_iter(ri, reply)
    for stmt in ('rsp = next;', 'next = rsp->next;'):
        cw.p(stmt)
    cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
1578
1579
def print_ntf_type_free(ri):
    """Emit the free() for a notification wrapper: its 'obj.' members, then rsp."""
    cw = ri.cw
    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    reply = ri.struct['reply']
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
1588
1589
def print_ntf_parse_prototype(family, cw, suffix=';'):
    """Emit the prototype of '<family>_ntf_parse(struct ynl_sock *ys)'."""
    fn = family['name'] + '_ntf_parse'
    cw.write_func_prot('struct ynl_ntf_base_type *', fn,
                       ['struct ynl_sock *ys'], suffix=suffix)
1593
1594
def print_ntf_type_parse(family, cw, ku_mode):
    """Emit '<family>_ntf_parse()', the user-space notification receiver.

    The generated C code:
      - receives one message (returning NULL if it is shorter than the
        netlink and genetlink headers);
      - switches on genlh->cmd;
      - for each notify command or event op, allocates the matching wrapper
        and selects its parse callback, reply policy and free function;
      - parses the message into rsp->data;
      - fills in family and cmd and returns the wrapper.
    Unknown commands are reported via ynl_error_unknown_notification().

    Args:
        family: parsed family; uses all_notify, ops and the family name.
        cw: CodeWriter to emit into.
        ku_mode: kernel/user space mode passed on to RenderInfo.
    """
    print_ntf_parse_prototype(family, cw, suffix='')
    cw.block_start()
    cw.write_func_lvar(['struct genlmsghdr *genlh;',
                        'struct nlmsghdr *nlh;',
                        'struct ynl_parse_arg yarg = { .ys = ys, };',
                        'struct ynl_ntf_base_type *rsp;',
                        'int len, err;',
                        'mnl_cb_t parse;'])
    cw.p('len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);')
    cw.p('if (len < (ssize_t)(sizeof(*nlh) + sizeof(*genlh)))')
    cw.p('return NULL;')
    cw.nl()
    cw.p('nlh = (struct nlmsghdr *)ys->rx_buf;')
    cw.p('genlh = mnl_nlmsg_get_payload(nlh);')
    cw.nl()
    cw.block_start(line='switch (genlh->cmd)')
    # Notifications: several commands may share one notify op's reply type.
    for ntf_op in sorted(family.all_notify.keys()):
        op = family.ops[ntf_op]
        ri = RenderInfo(cw, family, ku_mode, op, ntf_op, "notify")
        for ntf in op['notify']['cmds']:
            cw.p(f"case {ntf.enum_name}:")
        cw.p(f"rsp = calloc(1, sizeof({type_name(ri, 'notify')}));")
        cw.p(f"parse = {op_prefix(ri, 'reply', deref=True)}_parse;")
        cw.p(f"yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
        cw.p(f"rsp->free = (void *){op_prefix(ri, 'notify')}_free;")
        cw.p('break;')
    # Events: one case per op that declares an 'event' section.
    for op_name, op in family.ops.items():
        if 'event' not in op:
            continue
        ri = RenderInfo(cw, family, ku_mode, op, op_name, "event")
        cw.p(f"case {op.enum_name}:")
        cw.p(f"rsp = calloc(1, sizeof({type_name(ri, 'event')}));")
        cw.p(f"parse = {op_prefix(ri, 'reply', deref=True)}_parse;")
        cw.p(f"yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
        cw.p(f"rsp->free = (void *){op_prefix(ri, 'notify')}_free;")
        cw.p('break;')
    cw.p('default:')
    cw.p('ynl_error_unknown_notification(ys, genlh->cmd);')
    cw.p('return NULL;')
    cw.block_end()
    cw.nl()
    cw.p('yarg.data = rsp->data;')
    cw.nl()
    cw.p(f"err = {cw.nlib.parse_cb_run('parse', '&yarg', True)};")
    cw.p('if (err < 0)')
    cw.p('goto err_free;')
    cw.nl()
    cw.p('rsp->family = nlh->nlmsg_type;')
    cw.p('rsp->cmd = genlh->cmd;')
    cw.p('return rsp;')
    cw.nl()
    # NOTE(review): on a parse error only the wrapper is free()d, not
    # rsp->free(rsp).  Members allocated before the failure may leak; confirm.
    cw.p('err_free:')
    cw.p('free(rsp);')
    cw.p('return NULL;')
    cw.block_end()
    cw.nl()
1652
1653
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit an nla_policy array declaration, or the opener of its definition.

    Let "per-op local" mean: ri is given and
    kernel_can_gen_family_struct() is true for the struct's family.

    terminate=True: write an 'extern ...;' declaration, except for per-op
    local policies, where nothing is written.
    terminate=False: write the definition opener ('... = {'), marked
    'static' for per-op local policies.

    Per-op policies are named after the op, plus '_<mode>' for dual-policy
    ops.  Other policies are named after the struct.
    """
    per_op_local = bool(ri) and kernel_can_gen_family_struct(struct.family)

    if terminate:
        if per_op_local:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if per_op_local else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if ri:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name = name + '_' + ri.op_mode
    else:
        name = struct.render_name
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
1676
1677
def print_req_policy(cw, struct, ri=None):
    """Emit a complete nla_policy array: the opener, one entry per member, '};'."""
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
1683
1684
def kernel_can_gen_family_struct(family):
    """Return True when the family's protocol is plain 'genetlink'."""
    proto = family.proto
    return proto == 'genetlink'
1687
1688
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the kernel ops table declaration/opener and handler prototypes.

    terminate=False: open the ops array definition ('... = {').
    terminate=True: for exported tables (families where
    kernel_can_gen_family_struct() is false) emit an 'extern' declaration.
    Then, in every case, emit the prototypes of all pre/post hooks and of
    every doit/dumpit handler of non-async ops.

    The array size is left empty for non-exported tables.  For exported
    tables it is spelled out and must match the entries
    print_kernel_op_table() writes.  That function skips async
    (notification/event) ops, so they are not counted here either.
    Split tables have one entry per 'do'/'dump' mode.
    """
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        if not exported:
            cnt = ""
        else:
            sync_ops = [op for op in family.ops.values() if not op.is_async]
            if family.kernel_policy == 'split':
                cnt = sum(1 for op in sync_ops for mode in ('do', 'dump') if mode in op)
            else:
                cnt = len(sync_ops)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    cw.nl()
    doit_args = ['const struct genl_split_ops *ops',
                 'struct sk_buff *skb', 'struct genl_info *info']
    dump_args = ['struct netlink_callback *cb']
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name), doit_args, suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name), doit_args, suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name), dump_args, suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name), dump_args, suffix=';')

    cw.nl()

    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
1754
1755
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side ops table declarations (terminated form)."""
    print_kernel_op_table_fwd(family, cw, terminate=True)
1758
1759
1760def print_kernel_op_table(family, cw):
1761    print_kernel_op_table_fwd(family, cw, terminate=False)
1762    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
1763        for op_name, op in family.ops.items():
1764            if op.is_async:
1765                continue
1766
1767            cw.block_start()
1768            members = [('cmd', op.enum_name)]
1769            if 'dont-validate' in op:
1770                members.append(('validate',
1771                                ' | '.join([c_upper('genl-dont-validate-' + x)
1772                                            for x in op['dont-validate']])), )
1773            for op_mode in ['do', 'dump']:
1774                if op_mode in op:
1775                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
1776                    members.append((op_mode + 'it', name))
1777            if family.kernel_policy == 'per-op':
1778                struct = Struct(family, op['attribute-set'],
1779                                type_list=op['do']['request']['attributes'])
1780
1781                name = c_lower(f"{family.name}-{op_name}-nl-policy")
1782                members.append(('policy', name))
1783                members.append(('maxattr', struct.attr_max_val.enum_name))
1784            if 'flags' in op:
1785                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
1786            cw.write_struct_init(members)
1787            cw.block_end(line=',')
1788    elif family.kernel_policy == 'split':
1789        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
1790                    'dump': {'pre': 'start', 'post': 'done'}}
1791
1792        for op_name, op in family.ops.items():
1793            for op_mode in ['do', 'dump']:
1794                if op.is_async or op_mode not in op:
1795                    continue
1796
1797                cw.block_start()
1798                members = [('cmd', op.enum_name)]
1799                if 'dont-validate' in op:
1800                    members.append(('validate',
1801                                    ' | '.join([c_upper('genl-dont-validate-' + x)
1802                                                for x in op['dont-validate']])), )
1803                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
1804                if 'pre' in op[op_mode]:
1805                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
1806                members.append((op_mode + 'it', name))
1807                if 'post' in op[op_mode]:
1808                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
1809                if 'request' in op[op_mode]:
1810                    struct = Struct(family, op['attribute-set'],
1811                                    type_list=op[op_mode]['request']['attributes'])
1812
1813                    if op.dual_policy:
1814                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
1815                    else:
1816                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
1817                    members.append(('policy', name))
1818                    members.append(('maxattr', struct.attr_max_val.enum_name))
1819                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
1820                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
1821                cw.write_struct_init(members)
1822                cw.block_end(line=',')
1823
1824    cw.block_end(line=';')
1825    cw.nl()
1826
1827
1828def print_kernel_mcgrp_hdr(family, cw):
1829    if not family.mcgrps['list']:
1830        return
1831
1832    cw.block_start('enum')
1833    for grp in family.mcgrps['list']:
1834        grp_id = c_upper(f"{family.name}-nlgrp-{grp['name']},")
1835        cw.p(grp_id)
1836    cw.block_end(';')
1837    cw.nl()
1838
1839
1840def print_kernel_mcgrp_src(family, cw):
1841    if not family.mcgrps['list']:
1842        return
1843
1844    cw.block_start('static const struct genl_multicast_group ' + family.name + '_nl_mcgrps[] =')
1845    for grp in family.mcgrps['list']:
1846        name = grp['name']
1847        grp_id = c_upper(f"{family.name}-nlgrp-{name}")
1848        cw.p('[' + grp_id + '] = { "' + name + '", },')
1849    cw.block_end(';')
1850    cw.nl()
1851
1852
1853def print_kernel_family_struct_hdr(family, cw):
1854    if not kernel_can_gen_family_struct(family):
1855        return
1856
1857    cw.p(f"extern struct genl_family {family.name}_nl_family;")
1858    cw.nl()
1859
1860
1861def print_kernel_family_struct_src(family, cw):
1862    if not kernel_can_gen_family_struct(family):
1863        return
1864
1865    cw.block_start(f"struct genl_family {family.name}_nl_family __ro_after_init =")
1866    cw.p('.name\t\t= ' + family.fam_key + ',')
1867    cw.p('.version\t= ' + family.ver_key + ',')
1868    cw.p('.netnsok\t= true,')
1869    cw.p('.parallel_ops\t= true,')
1870    cw.p('.module\t\t= THIS_MODULE,')
1871    if family.kernel_policy == 'per-op':
1872        cw.p(f'.ops\t\t= {family.name}_nl_ops,')
1873        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.name}_nl_ops),')
1874    elif family.kernel_policy == 'split':
1875        cw.p(f'.split_ops\t= {family.name}_nl_ops,')
1876        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.name}_nl_ops),')
1877    if family.mcgrps['list']:
1878        cw.p(f'.mcgrps\t\t= {family.name}_nl_mcgrps,')
1879        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.name}_nl_mcgrps),')
1880    cw.block_end(';')
1881
1882
1883def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
1884    start_line = 'enum'
1885    if enum_name in obj:
1886        if obj[enum_name]:
1887            start_line = 'enum ' + c_lower(obj[enum_name])
1888    elif ckey and ckey in obj:
1889        start_line = 'enum ' + family.name + '_' + c_lower(obj[ckey])
1890    cw.block_start(line=start_line)
1891
1892
1893def render_uapi(family, cw):
1894    hdr_prot = f"_UAPI_LINUX_{family.name.upper()}_H"
1895    cw.p('#ifndef ' + hdr_prot)
1896    cw.p('#define ' + hdr_prot)
1897    cw.nl()
1898
1899    defines = [(family.fam_key, family["name"]),
1900               (family.ver_key, family.get('version', 1))]
1901    cw.writes_defines(defines)
1902    cw.nl()
1903
1904    defines = []
1905    for const in family['definitions']:
1906        if const['type'] != 'const':
1907            cw.writes_defines(defines)
1908            defines = []
1909            cw.nl()
1910
1911        # Write kdoc for enum and flags (one day maybe also structs)
1912        if const['type'] == 'enum' or const['type'] == 'flags':
1913            enum = family.consts[const['name']]
1914
1915            if enum.has_doc():
1916                cw.p('/**')
1917                doc = ''
1918                if 'doc' in enum:
1919                    doc = ' - ' + enum['doc']
1920                cw.write_doc_line(enum.enum_name + doc)
1921                for entry in enum.entries.values():
1922                    if entry.has_doc():
1923                        doc = '@' + entry.c_name + ': ' + entry['doc']
1924                        cw.write_doc_line(doc)
1925                cw.p(' */')
1926
1927            uapi_enum_start(family, cw, const, 'name')
1928            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
1929            for entry in enum.entries.values():
1930                suffix = ','
1931                if entry.value_change:
1932                    suffix = f" = {entry.user_value()}" + suffix
1933                cw.p(entry.c_name + suffix)
1934
1935            if const.get('render-max', False):
1936                cw.nl()
1937                if const['type'] == 'flags':
1938                    max_name = c_upper(name_pfx + 'mask')
1939                    max_val = f' = {enum.get_mask()},'
1940                    cw.p(max_name + max_val)
1941                else:
1942                    max_name = c_upper(name_pfx + 'max')
1943                    cw.p('__' + max_name + ',')
1944                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
1945            cw.block_end(line=';')
1946            cw.nl()
1947        elif const['type'] == 'const':
1948            defines.append([c_upper(family.get('c-define-name',
1949                                               f"{family.name}-{const['name']}")),
1950                            const['value']])
1951
1952    if defines:
1953        cw.writes_defines(defines)
1954        cw.nl()
1955
1956    max_by_define = family.get('max-by-define', False)
1957
1958    for _, attr_set in family.attr_sets.items():
1959        if attr_set.subset_of:
1960            continue
1961
1962        cnt_name = c_upper(family.get('attr-cnt-name', f"__{attr_set.name_prefix}MAX"))
1963        max_value = f"({cnt_name} - 1)"
1964
1965        val = 0
1966        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
1967        for _, attr in attr_set.items():
1968            suffix = ','
1969            if attr.value != val:
1970                suffix = f" = {attr.value},"
1971                val = attr.value
1972            val += 1
1973            cw.p(attr.enum_name + suffix)
1974        cw.nl()
1975        cw.p(cnt_name + ('' if max_by_define else ','))
1976        if not max_by_define:
1977            cw.p(f"{attr_set.max_name} = {max_value}")
1978        cw.block_end(line=';')
1979        if max_by_define:
1980            cw.p(f"#define {attr_set.max_name} {max_value}")
1981        cw.nl()
1982
1983    # Commands
1984    separate_ntf = 'async-prefix' in family['operations']
1985
1986    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
1987    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
1988    max_value = f"({cnt_name} - 1)"
1989
1990    uapi_enum_start(family, cw, family['operations'], 'enum-name')
1991    val = 0
1992    for op in family.msgs.values():
1993        if separate_ntf and ('notify' in op or 'event' in op):
1994            continue
1995
1996        suffix = ','
1997        if op.value != val:
1998            suffix = f" = {op.value},"
1999            val = op.value
2000        cw.p(op.enum_name + suffix)
2001        val += 1
2002    cw.nl()
2003    cw.p(cnt_name + ('' if max_by_define else ','))
2004    if not max_by_define:
2005        cw.p(f"{max_name} = {max_value}")
2006    cw.block_end(line=';')
2007    if max_by_define:
2008        cw.p(f"#define {max_name} {max_value}")
2009    cw.nl()
2010
2011    if separate_ntf:
2012        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
2013        for op in family.msgs.values():
2014            if separate_ntf and not ('notify' in op or 'event' in op):
2015                continue
2016
2017            suffix = ','
2018            if 'value' in op:
2019                suffix = f" = {op['value']},"
2020            cw.p(op.enum_name + suffix)
2021        cw.block_end(line=';')
2022        cw.nl()
2023
2024    # Multicast
2025    defines = []
2026    for grp in family.mcgrps['list']:
2027        name = grp['name']
2028        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
2029                        f'{name}'])
2030    cw.nl()
2031    if defines:
2032        cw.writes_defines(defines)
2033        cw.nl()
2034
2035    cw.p(f'#endif /* {hdr_prot} */')
2036
2037
2038def find_kernel_root(full_path):
2039    sub_path = ''
2040    while True:
2041        sub_path = os.path.join(os.path.basename(full_path), sub_path)
2042        full_path = os.path.dirname(full_path)
2043        maintainers = os.path.join(full_path, "MAINTAINERS")
2044        if os.path.exists(maintainers):
2045            return full_path, sub_path[:-1]
2046
2047
2048def main():
2049    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
2050    parser.add_argument('--mode', dest='mode', type=str, required=True)
2051    parser.add_argument('--spec', dest='spec', type=str, required=True)
2052    parser.add_argument('--header', dest='header', action='store_true', default=None)
2053    parser.add_argument('--source', dest='header', action='store_false')
2054    parser.add_argument('--user-header', nargs='+', default=[])
2055    parser.add_argument('-o', dest='out_file', type=str)
2056    args = parser.parse_args()
2057
2058    out_file = open(args.out_file, 'w+') if args.out_file else os.sys.stdout
2059
2060    if args.header is None:
2061        parser.error("--header or --source is required")
2062
2063    try:
2064        parsed = Family(args.spec)
2065        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
2066            print('Spec license:', parsed.license)
2067            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
2068            os.sys.exit(1)
2069    except yaml.YAMLError as exc:
2070        print(exc)
2071        os.sys.exit(1)
2072        return
2073
2074    cw = CodeWriter(BaseNlLib(), out_file)
2075
2076    _, spec_kernel = find_kernel_root(args.spec)
2077    if args.mode == 'uapi' or args.header:
2078        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
2079    else:
2080        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
2081    cw.p("/* Do not edit directly, auto-generated from: */")
2082    cw.p(f"/*\t{spec_kernel} */")
2083    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
2084    cw.nl()
2085
2086    if args.mode == 'uapi':
2087        render_uapi(parsed, cw)
2088        return
2089
2090    hdr_prot = f"_LINUX_{parsed.name.upper()}_GEN_H"
2091    if args.header:
2092        cw.p('#ifndef ' + hdr_prot)
2093        cw.p('#define ' + hdr_prot)
2094        cw.nl()
2095
2096    if args.mode == 'kernel':
2097        cw.p('#include <net/netlink.h>')
2098        cw.p('#include <net/genetlink.h>')
2099        cw.nl()
2100        if not args.header:
2101            if args.out_file:
2102                cw.p(f'#include "{os.path.basename(args.out_file[:-2])}.h"')
2103            cw.nl()
2104    headers = [parsed.uapi_header]
2105    for definition in parsed['definitions']:
2106        if 'header' in definition:
2107            headers.append(definition['header'])
2108    for one in headers:
2109        cw.p(f"#include <{one}>")
2110    cw.nl()
2111
2112    if args.mode == "user":
2113        if not args.header:
2114            cw.p("#include <stdlib.h>")
2115            cw.p("#include <stdio.h>")
2116            cw.p("#include <string.h>")
2117            cw.p("#include <libmnl/libmnl.h>")
2118            cw.p("#include <linux/genetlink.h>")
2119            cw.nl()
2120            for one in args.user_header:
2121                cw.p(f'#include "{one}"')
2122        else:
2123            cw.p('struct ynl_sock;')
2124        cw.nl()
2125
2126    if args.mode == "kernel":
2127        if args.header:
2128            for _, struct in sorted(parsed.pure_nested_structs.items()):
2129                if struct.request:
2130                    cw.p('/* Common nested types */')
2131                    break
2132            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2133                if struct.request:
2134                    print_req_policy_fwd(cw, struct)
2135            cw.nl()
2136
2137            if parsed.kernel_policy == 'global':
2138                cw.p(f"/* Global operation policy for {parsed.name} */")
2139
2140                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
2141                print_req_policy_fwd(cw, struct)
2142                cw.nl()
2143
2144            if parsed.kernel_policy in {'per-op', 'split'}:
2145                for op_name, op in parsed.ops.items():
2146                    if 'do' in op and 'event' not in op:
2147                        ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
2148                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
2149                        cw.nl()
2150
2151            print_kernel_op_table_hdr(parsed, cw)
2152            print_kernel_mcgrp_hdr(parsed, cw)
2153            print_kernel_family_struct_hdr(parsed, cw)
2154        else:
2155            for _, struct in sorted(parsed.pure_nested_structs.items()):
2156                if struct.request:
2157                    cw.p('/* Common nested types */')
2158                    break
2159            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2160                if struct.request:
2161                    print_req_policy(cw, struct)
2162            cw.nl()
2163
2164            if parsed.kernel_policy == 'global':
2165                cw.p(f"/* Global operation policy for {parsed.name} */")
2166
2167                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
2168                print_req_policy(cw, struct)
2169                cw.nl()
2170
2171            for op_name, op in parsed.ops.items():
2172                if parsed.kernel_policy in {'per-op', 'split'}:
2173                    for op_mode in ['do', 'dump']:
2174                        if op_mode in op and 'request' in op[op_mode]:
2175                            cw.p(f"/* {op.enum_name} - {op_mode} */")
2176                            ri = RenderInfo(cw, parsed, args.mode, op, op_name, op_mode)
2177                            print_req_policy(cw, ri.struct['request'], ri=ri)
2178                            cw.nl()
2179
2180            print_kernel_op_table(parsed, cw)
2181            print_kernel_mcgrp_src(parsed, cw)
2182            print_kernel_family_struct_src(parsed, cw)
2183
2184    if args.mode == "user":
2185        has_ntf = False
2186        if args.header:
2187            cw.p('/* Common nested types */')
2188            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2189                ri = RenderInfo(cw, parsed, args.mode, "", "", "", attr_set)
2190                print_type_full(ri, struct)
2191
2192            for op_name, op in parsed.ops.items():
2193                cw.p(f"/* ============== {op.enum_name} ============== */")
2194
2195                if 'do' in op and 'event' not in op:
2196                    cw.p(f"/* {op.enum_name} - do */")
2197                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
2198                    print_req_type(ri)
2199                    print_req_type_helpers(ri)
2200                    cw.nl()
2201                    print_rsp_type(ri)
2202                    print_rsp_type_helpers(ri)
2203                    cw.nl()
2204                    print_req_prototype(ri)
2205                    cw.nl()
2206
2207                if 'dump' in op:
2208                    cw.p(f"/* {op.enum_name} - dump */")
2209                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'dump')
2210                    if 'request' in op['dump']:
2211                        print_req_type(ri)
2212                        print_req_type_helpers(ri)
2213                    if not ri.type_consistent:
2214                        print_rsp_type(ri)
2215                    print_wrapped_type(ri)
2216                    print_dump_prototype(ri)
2217                    cw.nl()
2218
2219                if 'notify' in op:
2220                    cw.p(f"/* {op.enum_name} - notify */")
2221                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'notify')
2222                    has_ntf = True
2223                    if not ri.type_consistent:
2224                        raise Exception('Only notifications with consistent types supported')
2225                    print_wrapped_type(ri)
2226
2227                if 'event' in op:
2228                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'event')
2229                    cw.p(f"/* {op.enum_name} - event */")
2230                    print_rsp_type(ri)
2231                    cw.nl()
2232                    print_wrapped_type(ri)
2233
2234            if has_ntf:
2235                cw.p('/* --------------- Common notification parsing --------------- */')
2236                print_ntf_parse_prototype(parsed, cw)
2237            cw.nl()
2238        else:
2239            cw.p('/* Policies */')
2240            for name, _ in parsed.attr_sets.items():
2241                struct = Struct(parsed, name)
2242                put_typol_fwd(cw, struct)
2243            cw.nl()
2244
2245            for name, _ in parsed.attr_sets.items():
2246                struct = Struct(parsed, name)
2247                put_typol(cw, struct)
2248
2249            cw.p('/* Common nested types */')
2250            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2251                ri = RenderInfo(cw, parsed, args.mode, "", "", "", attr_set)
2252
2253                free_rsp_nested(ri, struct)
2254                if struct.request:
2255                    put_req_nested(ri, struct)
2256                if struct.reply:
2257                    parse_rsp_nested(ri, struct)
2258
2259            for op_name, op in parsed.ops.items():
2260                cw.p(f"/* ============== {op.enum_name} ============== */")
2261                if 'do' in op and 'event' not in op:
2262                    cw.p(f"/* {op.enum_name} - do */")
2263                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
2264                    print_rsp_free(ri)
2265                    parse_rsp_msg(ri)
2266                    print_req(ri)
2267                    cw.nl()
2268
2269                if 'dump' in op:
2270                    cw.p(f"/* {op.enum_name} - dump */")
2271                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, "dump")
2272                    if not ri.type_consistent:
2273                        parse_rsp_msg(ri, deref=True)
2274                    print_dump_type_free(ri)
2275                    print_dump(ri)
2276                    cw.nl()
2277
2278                if 'notify' in op:
2279                    cw.p(f"/* {op.enum_name} - notify */")
2280                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'notify')
2281                    has_ntf = True
2282                    if not ri.type_consistent:
2283                        raise Exception('Only notifications with consistent types supported')
2284                    print_ntf_type_free(ri)
2285
2286                if 'event' in op:
2287                    cw.p(f"/* {op.enum_name} - event */")
2288                    has_ntf = True
2289
2290                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
2291                    parse_rsp_msg(ri)
2292
2293                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, "event")
2294                    print_ntf_type_free(ri)
2295
2296            if has_ntf:
2297                cw.p('/* --------------- Common notification parsing --------------- */')
2298                print_ntf_type_parse(parsed, cw, args.mode)
2299
2300    if args.header:
2301        cw.p(f'#endif /* {hdr_prot} */')
2302
2303
2304if __name__ == "__main__":
2305    main()
2306