xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision 0ad9617c78acbc71373fb341a6f75d4012b01d69)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17
18
def c_upper(name):
    """
    Return @name in upper case with every '-' replaced by '_'
    """
    shouted = name.upper()
    return shouted.replace('-', '_')
21
22
def c_lower(name):
    """
    Return @name in lower case with every '-' replaced by '_'
    """
    lowered = name.lower()
    return lowered.replace('-', '_')
25
26
def limit_to_number(name):
    """
    Convert a symbolic limit such as u32-max or s64-min into an integer

    The first character selects the signedness ('u' or 's'), the digits
    after it are the bit width and the last four characters ('-min' or
    '-max') select which end of the range is returned.
    """
    is_signed = name[0] == 's'
    wants_min = name.endswith('-min')

    # The minimum of an unsigned type is zero whatever the width is
    if name[0] == 'u' and wants_min:
        return 0

    bits = int(name[1:-4])
    if is_signed:
        # One bit is consumed by the sign
        bits -= 1
    top = (1 << bits) - 1
    return -top - 1 if is_signed and wants_min else top
40
41
class BaseNlLib:
    """
    Provider of code snippets as strings; only the family ID is covered here.
    """

    # Text handed out by get_family_id()
    _FAMILY_ID_EXPR = 'ys->family_id'

    def get_family_id(self):
        """
        Return the string 'ys->family_id'
        """
        return self._FAMILY_ID_EXPR
45
46
class Type(SpecAttr):
    """
    Base class for rendering a single attribute into C code.

    The public methods write the scaffolding shared by all attribute kinds
    through the code writer (cw / ri.cw).  Kind specific parts are supplied
    by hooks which raise "not implemented" in this class and are meant to
    be overridden: _attr_typol(), attr_put(), _attr_get(), _setter_lines()
    and either _complex_member_type() or arg_member().
    """

    def __init__(self, family, attr_set, attr, value):
        """
        Record the raw attribute definition and derive the C identifiers.

        family   -- owning family, its name, ident_name and consts are read
        attr_set -- attribute set the attribute is defined in
        attr     -- raw attribute definition (mapping), 'type' is required
        value    -- handed to SpecAttr.__init__() unchanged
        """
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        # Flipped to True by set_request() / set_reply()
        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            # A nest named like the family itself gets no extra name component
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # Nests sharing their name with a family constant get a trailing '_'
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve():
        # (assigned and deleted right away, so the name is mentioned in
        # __init__ but the instance attribute is absent until resolve())
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """
        Mark the attribute as used in requests, including the attribute
        of the same name in the set this set is a subset of (if any).
        """
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """
        Mark the attribute as used in replies, including the attribute
        of the same name in the set this set is a subset of (if any).
        """
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """
        Return check @limit as a number, @default is used when it is unset.

        None and integers are returned as they are, other values are
        converted with limit_to_number().  Raises Exception if the value
        names a family constant.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            raise Exception("Resolving family constants not implemented, yet")
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """
        Return check @limit as text, '' when it is unset.

        Integers get @suffix appended, names of family constants are
        prefixed with the family name, everything is run through c_upper()
        unless it was an integer.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """
        Compute enum_name and reject subset attributes which override checks.
        """
        # A per-attribute 'name-prefix' wins over the prefix of the set
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        """
        Falsy in the base class, classes in this file which hold multiple
        values override this to return True.
        """
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """
        Return how presence of the attribute is tracked.

        'bit' in the base class; overrides in this file also return
        'len', 'count' and '' (the empty string).
        """
        return 'bit'

    def presence_member(self, space, type_filter):
        """
        Return the declaration of the presence member, or None.

        None is returned when presence_type() differs from @type_filter
        or is neither 'bit' nor 'len'.  The type is spelled __u32 when
        @space is 'user' and u32 otherwise.
        """
        kind = self.presence_type()
        if kind != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if kind == 'bit':
            return f"{pfx}u32 {self.c_name}:1;"
        if kind == 'len':
            return f"{pfx}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        """
        Hook: return the C type of the member for types which have one,
        None (the default) otherwise.
        """
        return None

    def free_needs_iter(self):
        return False

    def free(self, ri, var, ref):
        """
        Emit a free() of the member for multi-value and 'len' presence types.
        """
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """
        Return the list of C parameter declarations carrying the value.

        Only handles types with a _complex_member_type(), the value is
        passed by pointer and is followed by an element count for 'count'
        presence.  Raises Exception for all other types.
        """
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """
        Emit the struct member declaration(s) for the attribute.
        """
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            # Multi-value and recursive members are held by pointer
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """
        Emit the policy array entry for the attribute.
        """
        policy = f'NLA_{c_upper(self.type)}'
        # Only u16 and u32 are switched to a big-endian policy type here
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """
        Emit the type policy array entry, the body comes from _attr_typol().
        """
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        """
        Emit @line, guarded by a presence check for 'bit' and 'len' presence.
        """
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """
        Hook: return (lines, init_lines, local_vars) for attr_get().

        lines and init_lines may each be a single string or a list of
        strings, init_lines and local_vars may be None.
        """
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Emit the "if (type == ...)" block parsing the attribute, return True.

        @first selects between "if" and "else if" for the opening line.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        # Hooks may return a lone string in place of a list of lines
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """
        Emit the static inline setter function for the attribute.

        @ref lists the names of the enclosing members, outermost first,
        when the attribute is reached through other members.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        last = len(ref) - 1
        for i, name in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{name}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == last and self.presence_type() != 'bit':
                continue
            code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        # Setters which free the old value but allocate nothing are
        # emitted with a '__' name prefix
        has_free = any('free(' in line for line in code)
        has_alloc = any('alloc(' in line for line in code)
        if has_free and not has_alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
279
280
class TypeUnused(Type):
    """
    Attribute of type 'unused': no struct member, no policy entry, no
    put / get / setter code, and a rejecting type policy.
    """

    def presence_type(self):
        """No presence tracking."""
        return ''

    def arg_member(self, ri):
        """No arguments, hence no struct members either."""
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        # Not reached through attr_get() of this class, which emits nothing
        error_lines = ['return YNL_PARSE_CB_ERROR;']
        return error_lines, None, None

    def attr_policy(self, cw):
        """Emit nothing."""
        return None

    def attr_put(self, ri, var):
        """Emit nothing."""
        return None

    def attr_get(self, ri, var, first):
        """Emit nothing."""
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit nothing."""
        return None
305
306
class TypePad(Type):
    """
    Attribute of type 'pad': no struct member, no policy entry, no
    put / get / setter code, and an ignoring type policy.
    """

    def presence_type(self):
        """No presence tracking."""
        return ''

    def arg_member(self, ri):
        """No arguments, hence no struct members either."""
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def attr_policy(self, cw):
        """Emit nothing."""
        return None

    def attr_put(self, ri, var):
        """Emit nothing."""
        return None

    def attr_get(self, ri, var, first):
        """Emit nothing."""
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit nothing."""
        return None
328
329
class TypeScalar(Type):
    """
    Scalar attribute, optionally tied to an enum and limited by checks.
    """

    def __init__(self, family, attr_set, attr, value):
        """
        Derive the implied checks ('min', 'max', 'range', 'full-range').
        """
        super().__init__(family, attr_set, attr, value)

        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"
        else:
            self.byte_order_comment = ''

        # Limits not given explicitly are taken from the value range of the enum
        if 'enum' in self.attr:
            lowest, highest = self.family.consts[self.attr['enum']].value_range()
            if 'min' not in self.checks and (lowest != 0 or self.type[0] == 's'):
                self.checks['min'] = lowest
            self.checks.setdefault('max', highest)

        if 'min' in self.checks and 'max' in self.checks:
            min_lim = self.get_limit('min')
            max_lim = self.get_limit('max')
            if min_lim > max_lim:
                raise Exception(f'Invalid limit for "{self.name}" min: {min_lim} max: {max_lim}')
            self.checks['range'] = True

        # Missing limits count as 0 for the checks below
        bounds = (self.get_limit('min', 0), self.get_limit('max', 0))
        low, high = min(bounds), max(bounds)
        if self.type[0] == 'u' and low < 0:
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside of -32768..32767 are flagged as 'full-range'
        if not (-32768 <= low and high <= 32767):
            self.checks['full-range'] = True

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """
        Compute is_bitfield and type_name.
        """
        self.resolve_up(super())

        has_enum = 'enum' in self.attr

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif has_enum:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if has_enum and not self.is_bitfield:
            c_type = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            # Auto scalars are stored as the 64 bit type of the same signedness
            c_type = '__' + self.type[0] + '64'
        else:
            c_type = '__' + self.type
        self.type_name = c_type

    def _attr_policy(self, policy):
        """
        Return the policy initializer matching the checks of the attribute.
        """
        checks = self.checks

        if self.is_bitfield:
            mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'flags-mask' in checks:
            # One mask bit per entry of the referenced flags definition
            flags = self.family.consts[checks['flags-mask']]
            mask = (1 << len(flags['entries'])) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'full-range' in checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        if 'range' in checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        if 'min' in checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        if 'max' in checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # The first character of the type is dropped, 'U' is always used
        width = c_upper(self.type[1:])
        return f'.type = YNL_PT_U{width}, '

    def arg_member(self, ri):
        decl = f'{self.type_name} {self.c_name}{self.byte_order_comment}'
        return [decl]

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, put_type=self.type)

    def _attr_get(self, ri, var):
        assign = f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);"
        return assign, None, None

    def _setter_lines(self, ri, member, presence):
        assign = f"{member} = {self.c_name};"
        return [assign]
416
417
class TypeFlag(Type):
    """
    Attribute of type 'flag': carries no payload, only presence.
    """

    def arg_member(self, ri):
        """No value to pass, so no arguments."""
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        # Put the attribute with a NULL payload of length 0
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        """No lines besides what Type.attr_get() emits by itself."""
        no_lines = []
        return no_lines, None, None

    def _setter_lines(self, ri, member, presence):
        """No lines besides the presence handling done by Type.setter()."""
        return list()
433
434
class TypeString(Type):
    """
    Attribute of type 'string', held as a char pointer with 'len' presence.
    """

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        decl = f"const char *{self.c_name}"
        return [decl]

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        """
        Return the policy initializer, 'exact-len' wins over 'max-len'.
        """
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'

        parts = ['{ .type = ' + policy]
        if 'max-len' in self.checks:
            parts.append(', .len = ' + self.get_limit_str('max-len'))
        parts.append(', }')
        return ''.join(parts)

    def attr_policy(self, cw):
        """
        Emit the policy array entry for the attribute.
        """
        may_lack_nul = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if may_lack_nul else 'NLA_NUL_STRING'

        entry = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {entry},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, put_type='str')

    def _attr_get(self, ri, var):
        """
        Return the lines copying the string into a malloc()ed buffer of
        len + 1 bytes and terminating it.
        """
        dst = f"{var}->{self.c_name}"
        local_vars = ['unsigned int len;']
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        get_lines = [
            f"{var}->_present.{self.c_name}_len = len;",
            f"{dst} = malloc(len + 1);",
            f"memcpy({dst}, ynl_attr_get_str(attr), len);",
            f"{dst}[len] = 0;",
        ]
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        """
        Return the lines replacing the stored string with a copy of the argument.
        """
        len_mem = f"{presence}_len"
        return [
            f"free({member});",
            f"{len_mem} = strlen({self.c_name});",
            f"{member} = malloc({len_mem} + 1);",
            f'memcpy({member}, {self.c_name}, {len_mem});',
            f'{member}[{len_mem}] = 0;',
        ]
485
486
class TypeBinary(Type):
    """
    Attribute of type 'binary', held as a void pointer with 'len' presence.
    """

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        data_arg = f"const void *{self.c_name}"
        return [data_arg, 'size_t len']

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        """
        Return the policy initializer.

        At most one check is supported and it has to be one of
        'exact-len', 'min-len' or 'max-len'; Exception is raised otherwise.
        """
        n_checks = len(self.checks)
        if n_checks > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        if n_checks == 0:
            return '{ .type = NLA_BINARY, }'

        check_name = list(self.checks)[0]
        if check_name not in {'exact-len', 'min-len', 'max-len'}:
            raise Exception('Unsupported check for binary type: ' + check_name)

        limit = self.get_limit_str(check_name)
        if check_name == 'exact-len':
            return 'NLA_POLICY_EXACT_LEN(' + limit + ')'
        if check_name == 'min-len':
            return '{ .len = ' + limit + ', }'
        return 'NLA_POLICY_MAX_LEN(' + limit + ')'

    def attr_put(self, ri, var):
        data = f"{var}->{self.c_name}"
        size = f"{var}->_present.{self.c_name}_len"
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, {data}, {size})")

    def _attr_get(self, ri, var):
        """
        Return the lines copying the payload into a malloc()ed buffer.
        """
        dst = f"{var}->{self.c_name}"
        local_vars = ['unsigned int len;']
        init_lines = ['len = ynl_attr_data_len(attr);']
        get_lines = [
            f"{var}->_present.{self.c_name}_len = len;",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        """
        Return the lines replacing the stored data with a copy of the argument.
        """
        len_mem = f"{presence}_len"
        return [
            f"free({member});",
            f"{len_mem} = len;",
            f"{member} = malloc({len_mem});",
            f'memcpy({member}, {self.c_name}, {len_mem});',
        ]
538
539
class TypeBitfield32(Type):
    """
    Attribute of type 'bitfield32', held as struct nla_bitfield32.
    """

    # C type of the member, also used for the sizeof() in the emitted code
    _C_STRUCT = "struct nla_bitfield32"

    def _complex_member_type(self, ri):
        return self._C_STRUCT

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        """
        Return the policy initializer, the mask comes from the mandatory enum.
        """
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        put_call = (f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, "
                    f"sizeof({self._C_STRUCT}))")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        copy = f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof({self._C_STRUCT}));"
        return copy, None, None

    def _setter_lines(self, ri, member, presence):
        copy = f"memcpy(&{member}, {self.c_name}, sizeof({self._C_STRUCT}));"
        return [copy]
563
564
class TypeNest(Type):
    """
    Attribute of type 'nest', rendered via the struct of the nested set.
    """

    def is_recursive(self):
        nested_struct = self.family.pure_nested_structs[self.nested_attrs]
        return nested_struct.recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def free(self, ri, var, ref):
        """
        Emit the call freeing the nested struct.

        The member is passed as is and checked for being set first when
        recursive for the op, its address is passed otherwise.
        """
        target = f'{var}->{ref}{self.c_name}'
        if self.is_recursive_for_op(ri):
            ri.cw.p(f'if ({target})')
            ri.cw.p(f'{self.nested_render_name}_free({target});')
        else:
            ri.cw.p(f'{self.nested_render_name}_free(&{target});')

    def _attr_typol(self):
        nest = self.nested_render_name
        return f'.type = YNL_PT_NEST, .nest = &{nest}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        # Same by-value / by-address distinction as in free()
        if self.is_recursive_for_op(ri):
            src = f"{var}->{self.c_name}"
        else:
            src = f"&{var}->{self.c_name}"
        put_call = f"{self.nested_render_name}_put(nlh, {self.enum_name}, {src})"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        """
        Return the lines pointing parg at the member and parsing into it.
        """
        nest = self.nested_render_name
        init_lines = [
            f"parg.rsp_policy = &{nest}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        get_lines = [
            f"if ({nest}_parse(&parg, attr))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """
        Emit setters for the non-recursive members of the nested struct.
        """
        path = (ref if ref else []) + [self.c_name]

        nested_struct = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested_struct.member_list():
            if not member.is_recursive():
                member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
604
605
class TypeMultiAttr(Type):
    """
    Multi-valued attribute, the values are kept in an array accompanied
    by an n_<name> element count.

    Wraps @base_type, to which rendering of the policy and of the type
    policy is delegated.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        """
        Same as Type.__init__(), plus @base_type, the object whose
        _attr_policy() and _attr_typol() are used for this attribute.
        """
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def _is_nested_type(self):
        """
        Return True if the elements are nests, a missing 'type' counts as nest.
        """
        # Presence of the key is tested before it is subscripted, callers
        # must consult this helper before touching self.attr['type'].
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        """
        Return the C type of one element.

        Raises Exception for element types other than nest and scalars.
        """
        if self._is_nested_type():
            return self.nested_struct_type
        if self.attr['type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        """
        Return True if free() emits a loop using the iterator variable 'i'.
        """
        return self._is_nested_type()

    def free(self, ri, var, ref):
        """
        Emit the code freeing the array, nests are freed one by one first.

        Raises Exception for element types other than nest and scalars.
        """
        # The nest test goes first, like in _complex_member_type(), so
        # that a missing 'type' is handled instead of raising KeyError.
        if self._is_nested_type():
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        elif self.attr['type'] in scalars:
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        else:
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only the occurrences are counted here
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        """
        Emit a loop putting every element of the array.

        Raises Exception for element types other than nest and scalars.
        """
        # Nest test first for the same reason as in free()
        if self._is_nested_type():
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        elif self.attr['type'] in scalars:
            put_type = self.type
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        """
        Return the lines replacing the stored array pointer and its count.
        """
        # For multi-attr we have a count, not presence, hack up the presence
        presence = presence[:-(len('_present.') + len(self.c_name))] + "n_" + self.c_name
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
667
668
class TypeArrayNest(Type):
    """
    Attribute of type 'indexed-array', held as an array with an
    n_<name> element count.
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        """
        Return the C type of one element, a missing 'sub-type' counts as nest.

        Raises Exception for sub-types other than nest and scalars.
        """
        sub_type = self.attr.get('sub-type', 'nest')
        if sub_type == 'nest':
            return self.nested_struct_type
        if sub_type in scalars:
            return ('__' if ri.ku_space == 'user' else '') + sub_type
        raise Exception(f"Sub-type {sub_type} not supported yet")

    def _attr_typol(self):
        nest = self.nested_render_name
        return f'.type = YNL_PT_NEST, .nest = &{nest}_nest, '

    def _attr_get(self, ri, var):
        """
        Return the lines remembering the attribute and counting its entries.
        """
        count_lines = [
            f'attr_{self.c_name} = attr;',
            'ynl_attr_for_each_nested(attr2, attr)',
            f'\t{var}->n_{self.c_name}++;',
        ]
        return count_lines, None, ['const struct nlattr *attr2;']
694
695
class TypeNestTypeValue(Type):
    """
    Attribute of type 'nest-type-value'.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        nest = self.nested_render_name
        return f'.type = YNL_PT_NEST, .nest = &{nest}_nest, '

    def _attr_get(self, ri, var):
        """
        Return the lines parsing the nest.

        For every name listed under 'type-value' one nesting level is
        entered, the attribute type found there is stored in a variable
        of that name and passed to the parse function as an extra argument.
        """
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        local_vars = []
        get_lines = []
        innermost = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]

            attr_decls = ', *attr_'.join(levels)
            local_vars.append(f'const struct nlattr *attr_{attr_decls};')
            value_decls = ', '.join(levels)
            local_vars.append(f'__u32 {value_decls};')

            for level in levels:
                get_lines.append(f'attr_{level} = ynl_attr_data({innermost});')
                get_lines.append(f'{level} = ynl_attr_type(attr_{level});')
                innermost = 'attr_' + level

            extra_args = ', ' + ', '.join(levels)

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {innermost}{extra_args});")
        return get_lines, init_lines, local_vars
724
725
class Struct:
    """
    Members and naming of the C struct rendered for one attribute set.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        """
        family     -- owning family, its attr_sets, consts and names are read
        space_name -- name of the attribute set the struct is built from
        type_list  -- names of the attributes to include, the whole set is
                      used (and the struct is marked as nested) when None
        inherited  -- names later verified by set_inherited()
        """
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None

        # The set named like the family itself gets no extra name component
        base_name = family.ident_name
        if family.name != c_lower(space_name):
            base_name = base_name + '-' + space_name
        self.render_name = c_lower(base_name)

        # Nested structs sharing their name with a family constant get a '_'
        suffix = '_' if self.nested and space_name in family.consts else ''
        self.struct_name = 'struct ' + self.render_name + suffix
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(t, self.attr_set[t]) for t in names]
        self.attrs = dict(self.attr_list)

        # Of the attributes with a value >= 0 remember the one with the
        # highest value, the later one on a tie
        self.attr_max_val = None
        highest = 0
        for _, attr in self.attr_list:
            if attr.value >= highest:
                highest = attr.value
                self.attr_max_val = attr

    def __iter__(self):
        for name in self.attrs:
            yield name

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """
        Return the list of (name, attribute) tuples of the struct.
        """
        return self.attr_list

    def set_inherited(self, new_inherited):
        """
        Fill self.inherited with the sorted, c_lower()ed inherited names.

        Raises Exception if @new_inherited differs from the list the
        struct was created with.
        """
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        rendered = []
        for member in sorted(self._inherited):
            rendered.append(c_lower(member))
        self.inherited = rendered
781
782
class EnumEntry(SpecEnumEntry):
    """
    Enum entry with the extra state needed to render it in C.
    """

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # The value which would directly follow @prev, 0 when there is no @prev
        expected = prev.value + 1 if prev else 0
        # True if the value differs from the expected one; always True for flags
        self.value_change = self.value != expected or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """
        Compute c_name from the value prefix of the enum set and the name.
        """
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
801
802
class EnumSet(SpecEnumSet):
    """
    Enum definition with the names needed to render it in C.
    """

    def __init__(self, family, yaml):
        """
        Work out the naming first, then initialize the SpecEnumSet part.
        """
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        # 'enum-name' overrides the default type name, an empty / false
        # value leaves the enum without a type name
        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None

        # Enums without a type name are exposed as int
        self.user_type = self.enum_name if self.enum_name else 'int'

        default_pfx = f"{family.ident_name}-{yaml['name']}-"
        self.value_pfx = yaml.get('name-prefix', default_pfx)
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """
        Return the (lowest, highest) value of the entries.

        Raises Exception if the values do not form a contiguous range.
        """
        values = [entry.value for entry in self.entries.values()]
        low = min(values)
        high = max(values)

        if len(self.entries) != high - low + 1:
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
838
839
class AttrSet(SpecAttrSet):
    """
    Attribute set from the spec, extended with the C naming information
    (enum prefix, max / count names) and a factory for typed attributes.
    """

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets reuse the naming of the set they are carved out of
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)

            max_default = f"{self.name_prefix}max"
            cnt_default = f"__{self.name_prefix}max"
            self.max_name = c_upper(self.yaml.get('attr-max-name', max_default))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', cnt_default))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Compute c_name: lower-cased name, C keywords escaped with a
        trailing underscore, empty when it equals the family's c_name."""
        ident = c_lower(self.name)
        if ident in _C_KW:
            ident += '_'
        if ident == self.family.c_name:
            ident = ''
        self.c_name = ident

    def new_attr(self, elem, value):
        """
        Instantiate the Type* class matching the attribute's 'type'.

        Raises an Exception for unknown types and for indexed arrays with a
        sub-type other than 'nest'. Attributes marked 'multi-attr' are
        wrapped in TypeMultiAttr.
        """
        kind = elem['type']
        by_kind = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'binary': TypeBinary,
            'bitfield32': TypeBitfield32,
            'nest': TypeNest,
            'nest-type-value': TypeNestTypeValue,
        }

        if kind in scalars:
            type_cls = TypeScalar
        elif kind in by_kind:
            type_cls = by_kind[kind]
        elif kind == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] != 'nest':
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            type_cls = TypeArrayNest
        else:
            raise Exception(f"No typed class for type {elem['type']}")

        t = type_cls(self.family, self, elem, value)

        if 'multi-attr' in elem and elem['multi-attr']:
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
901
902
class Operation(SpecOperation):
    """
    Operation from the spec, extended with its rendered C name, the enum
    name of the command and notification bookkeeping.
    """

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # True only when both 'do' and 'dump' define a request
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Resolve the base class, then build enum_name from the family's
        regular or async operation prefix."""
        self.resolve_up(super())

        if self.is_async:
            prefix = self.family.async_op_prefix
        else:
            prefix = self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that a notification refers to this operation."""
        self.has_ntf = True
928
929
930class Family(SpecFamily):
931    def __init__(self, file_name, exclude_ops):
932        # Added by resolve:
933        self.c_name = None
934        delattr(self, "c_name")
935        self.op_prefix = None
936        delattr(self, "op_prefix")
937        self.async_op_prefix = None
938        delattr(self, "async_op_prefix")
939        self.mcgrps = None
940        delattr(self, "mcgrps")
941        self.consts = None
942        delattr(self, "consts")
943        self.hooks = None
944        delattr(self, "hooks")
945
946        super().__init__(file_name, exclude_ops=exclude_ops)
947
948        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
949        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))
950
951        if 'definitions' not in self.yaml:
952            self.yaml['definitions'] = []
953
954        if 'uapi-header' in self.yaml:
955            self.uapi_header = self.yaml['uapi-header']
956        else:
957            self.uapi_header = f"linux/{self.ident_name}.h"
958        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
959            self.uapi_header_name = self.uapi_header[6:-2]
960        else:
961            self.uapi_header_name = self.ident_name
962
963    def resolve(self):
964        self.resolve_up(super())
965
966        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
967            raise Exception("Codegen only supported for genetlink")
968
969        self.c_name = c_lower(self.ident_name)
970        if 'name-prefix' in self.yaml['operations']:
971            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
972        else:
973            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
974        if 'async-prefix' in self.yaml['operations']:
975            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
976        else:
977            self.async_op_prefix = self.op_prefix
978
979        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})
980
981        self.hooks = dict()
982        for when in ['pre', 'post']:
983            self.hooks[when] = dict()
984            for op_mode in ['do', 'dump']:
985                self.hooks[when][op_mode] = dict()
986                self.hooks[when][op_mode]['set'] = set()
987                self.hooks[when][op_mode]['list'] = []
988
989        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
990        self.root_sets = dict()
991        # dict space-name -> set('request', 'reply')
992        self.pure_nested_structs = dict()
993
994        self._mark_notify()
995        self._mock_up_events()
996
997        self._load_root_sets()
998        self._load_nested_sets()
999        self._load_attr_use()
1000        self._load_hooks()
1001
1002        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
1003        if self.kernel_policy == 'global':
1004            self._load_global_policy()
1005
    def new_enum(self, elem):
        """Create the code generator's EnumSet for the enum definition elem."""
        return EnumSet(self, elem)
1008
    def new_attr_set(self, elem):
        """Create the code generator's AttrSet for the attribute set elem."""
        return AttrSet(self, elem)
1011
    def new_operation(self, elem, req_value, rsp_value):
        """Create the code generator's Operation for elem with the given
        request and response values."""
        return Operation(self, elem, req_value, rsp_value)
1014
1015    def _mark_notify(self):
1016        for op in self.msgs.values():
1017            if 'notify' in op:
1018                self.ops[op['notify']].mark_has_ntf()
1019
1020    # Fake a 'do' equivalent of all events, so that we can render their response parsing
1021    def _mock_up_events(self):
1022        for op in self.yaml['operations']['list']:
1023            if 'event' in op:
1024                op['do'] = {
1025                    'reply': {
1026                        'attributes': op['event']['attributes']
1027                    }
1028                }
1029
1030    def _load_root_sets(self):
1031        for op_name, op in self.msgs.items():
1032            if 'attribute-set' not in op:
1033                continue
1034
1035            req_attrs = set()
1036            rsp_attrs = set()
1037            for op_mode in ['do', 'dump']:
1038                if op_mode in op and 'request' in op[op_mode]:
1039                    req_attrs.update(set(op[op_mode]['request']['attributes']))
1040                if op_mode in op and 'reply' in op[op_mode]:
1041                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
1042            if 'event' in op:
1043                rsp_attrs.update(set(op['event']['attributes']))
1044
1045            if op['attribute-set'] not in self.root_sets:
1046                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
1047            else:
1048                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
1049                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)
1050
    def _sort_pure_types(self):
        """
        Reorder self.pure_nested_structs so that a struct is placed after
        the non-recursive structs it nests.

        The number of passes is bounded, so the loop always terminates.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            # Stays True if every nest used by 'name' was already placed
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                    if self.pure_nested_structs[nested].recursive:
                        continue
                    if nested not in pns_key_seen:
                        # Dicts are sorted, this will make struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                # Retry this struct after the ones still waiting in the queue
                pns_key_list.append(name)
1077
    def _load_nested_sets(self):
        """
        Discover the attribute sets which are only used as nests and fill
        self.pure_nested_structs with a Struct for each of them.

        Walks the 'nested-attributes' references starting from the root
        sets, records the inherited members of each nest, marks the structs
        as used in requests / replies, and flags recursive nests.

        Raises an Exception when an attribute set is used both as a root
        set and as a nest.
        """
        # Breadth-first walk over the sets reachable from the root sets
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                # Members the nest gets from the attribute carrying it
                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                if 'type-value' in spec:
                    # NOTE(review): looks unreachable - a nest which is also
                    # a root set already raised just above.
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'indexed-array':
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        # Nests referenced directly from a root set take the request / reply
        # usage of the attribute which carries them
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                    struct.child_nests.add(child_name)
                    child = self.pure_nested_structs.get(child_name)
                    if child:
                        if not child.recursive:
                            struct.child_nests.update(child.child_nests)
                        child.request |= struct.request
                        child.reply |= struct.reply
                # A struct found among its own (transitive) children is recursive
                if attr_set in struct.child_nests:
                    struct.recursive = True

        # Sort again, now that the recursive flags are known
        self._sort_pure_types()
1135
1136    def _load_attr_use(self):
1137        for _, struct in self.pure_nested_structs.items():
1138            if struct.request:
1139                for _, arg in struct.member_list():
1140                    arg.set_request()
1141            if struct.reply:
1142                for _, arg in struct.member_list():
1143                    arg.set_reply()
1144
1145        for root_set, rs_members in self.root_sets.items():
1146            for attr, spec in self.attr_sets[root_set].items():
1147                if attr in rs_members['request']:
1148                    spec.set_request()
1149                if attr in rs_members['reply']:
1150                    spec.set_reply()
1151
1152    def _load_global_policy(self):
1153        global_set = set()
1154        attr_set_name = None
1155        for op_name, op in self.ops.items():
1156            if not op:
1157                continue
1158            if 'attribute-set' not in op:
1159                continue
1160
1161            if attr_set_name is None:
1162                attr_set_name = op['attribute-set']
1163            if attr_set_name != op['attribute-set']:
1164                raise Exception('For a global policy all ops must use the same set')
1165
1166            for op_mode in ['do', 'dump']:
1167                if op_mode in op:
1168                    req = op[op_mode].get('request')
1169                    if req:
1170                        global_set.update(req.get('attributes', []))
1171
1172        self.global_policy = []
1173        self.global_policy_set = attr_set_name
1174        for attr in self.attr_sets[attr_set_name]:
1175            if attr in global_set:
1176                self.global_policy.append(attr)
1177
1178    def _load_hooks(self):
1179        for op in self.ops.values():
1180            for op_mode in ['do', 'dump']:
1181                if op_mode not in op:
1182                    continue
1183                for when in ['pre', 'post']:
1184                    if when not in op[op_mode]:
1185                        continue
1186                    name = op[op_mode][when]
1187                    if name in self.hooks[when][op_mode]['set']:
1188                        continue
1189                    self.hooks[when][op_mode]['set'].add(name)
1190                    self.hooks[when][op_mode]['list'].append(name)
1191
1192
class RenderInfo:
    """
    Everything needed to render code for one operation in one mode, or
    for a bare attribute set when no operation is given.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        if op and op.fixed_header:
            self.fixed_hdr = 'struct ' + c_lower(op.fixed_header)
        else:
            self.fixed_hdr = None

        # 'do' and 'dump' response parsing is identical
        self.type_consistent = True
        if op_mode != 'do' and 'dump' in op:
            if 'do' not in op:
                self.type_consistent = False
            else:
                do_has_reply = 'reply' in op['do']
                dump_has_reply = 'reply' in op["dump"]
                if do_has_reply != dump_has_reply:
                    self.type_consistent = False
                elif do_has_reply and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False

        # Fall back to the operation's own attribute set
        self.attr_set = attr_set if attr_set else op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            if attr_set in family.consts:
                self.type_name_conflict = True

        self.cw = cw

        self.struct = dict()
        # Notifications are rendered from the 'do' description
        if op_mode == 'notify':
            op_mode = 'do'
        if op:
            mode_spec = op[op_mode]
            for op_dir in ['request', 'reply']:
                if op_dir in mode_spec:
                    type_list = mode_spec[op_dir]['attributes']
                else:
                    type_list = []
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1241
1242
class CodeWriter:
    """
    Line-oriented writer for the generated code.

    Output goes to stdout, or - when out_file is given - to a temporary
    file which close_out_file() copies over out_file. The writer tracks
    the indentation level, pending blank lines and closing brackets whose
    printing is delayed so that a following "else" can share their line.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        """
        nlib is stored for the users of the writer. With overwrite=False
        an existing out_file is left untouched if its contents would not
        change.
        """
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False
        self._block_end = False
        self._silent_block = False
        self._ind = 0
        self._ifdef_block = None
        # Start from the stdout state, so that __del__ stays safe even if
        # creating the temporary file fails.
        self._out = sys.stdout
        self._out_file = out_file
        if out_file is not None:
            self._out = tempfile.NamedTemporaryFile('w+')

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """
        Copy the buffered output to out_file and release the temporary
        file. Does nothing when writing to stdout or when already closed.
        """
        if self._out_file is None or self._out is sys.stdout:
            return
        # Avoid modifying the file if contents didn't change
        self._out.flush()
        try:
            unchanged = (not self._overwrite and
                         os.path.isfile(self._out_file) and
                         filecmp.cmp(self._out.name, self._out_file, shallow=False))
            if not unchanged:
                self._out.seek(0)
                with open(self._out_file, 'w+') as out_file:
                    shutil.copyfileobj(self._out, out_file)
        finally:
            # Always release the temporary file, also on the "unchanged"
            # path, and make repeated calls (e.g. from __del__) no-ops.
            self._out.close()
            self._out = sys.stdout

    @classmethod
    def _is_cond(cls, line):
        """Return True if the line starts like an if / while / for statement."""
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """
        Print one line at the current indentation.

        Lines ending with ':' are printed one level shallower, lines
        starting with '#' at column zero, and the line following a
        brace-less condition one level deeper. add_ind adjusts the
        resulting level.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        # endswith / startswith rather than indexing, so that an empty
        # line does not raise IndexError
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next printed line."""
        self._nl = True

    def block_start(self, line=''):
        """Print line followed by an opening bracket and indent."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """
        Close the current block. line is appended after the closing
        bracket; without it the bracket is printed lazily, see below.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if self._block_end:
            # A delayed bracket of the inner block is still pending, print
            # it at the inner block's own level before closing this one.
            self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """
        Print doc as lines of a comment body, each starting with ' *' and
        kept under 80 columns; continuation lines get extra indentation
        if indent is set.
        """
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """
        Print a function prototype, wrapping the arguments if it does not
        fit in 80 columns. doc, if given, is printed as a comment above.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        # Too long - long return types go on a line of their own
        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        # Continuation lines are aligned with the opening parenthesis,
        # using tabs first and spaces for the remainder
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        for arg in args[1:]:
            next_len = len(v) + len(arg)
            if v.startswith('\t'):
                # Account for tabs being wider than one character
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += arg
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """
        Print local variable declarations, longest first, followed by a
        blank line. Accepts a single string or a list of strings; the
        caller's list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Print a complete function: prototype, local variables and body."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)

        # Declarations belong inside the function body
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)
        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """
        Print '#define NAME value' lines for (name, value) pairs with the
        values aligned using tabs. int values are printed as is, str
        values in double quotes.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """
        Print '.member = value,' initializer lines for (member, value)
        pairs, with the '=' signs aligned using tabs.
        """
        if not members:
            return

        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """
        Switch the '#ifdef CONFIG_...' section the following output is in.
        Closes the previous section if there was one; a false config
        leaves no section open.
        """
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1438
1439
# Attribute types which are handled as scalars
scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64', 'uint', 'sint'}

# Suffix added to generated names depending on the message direction
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': ''
}

# Suffix added to generated names depending on the operation mode
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# C keywords, names matching one of these cannot be used as identifiers as is
_C_KW = {
    'auto',
    'bool',
    'break',
    'case',
    'char',
    'const',
    'continue',
    'default',
    'do',
    'double',
    'else',
    'enum',
    'extern',
    'float',
    'for',
    'goto',
    'if',
    'inline',
    'int',
    'long',
    'register',
    'return',
    'short',
    'signed',
    'sizeof',
    'static',
    'struct',
    'switch',
    'typedef',
    'union',
    'unsigned',
    'void',
    'volatile',
    'while'
}
1491
1492
def rdir(direction):
    """Return the opposite of 'request' / 'reply'; anything else is
    returned unchanged."""
    if direction == 'request':
        return 'reply'
    return 'request' if direction == 'reply' else direction
1499
1500
def op_prefix(ri, direction, deref=False):
    """
    Return the base name ("<family>_<type><suffix>") of the generated type
    for the given direction.

    The suffix depends on the operation mode, on whether dump replies
    share the type of 'do' replies (ri.type_consistent) and on whether the
    wrapper or the wrapped object (deref) is being named.
    """
    if not ri.op_mode or ri.op_mode == 'do':
        tail = f"{direction_to_suffix[direction]}"
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = f"{direction_to_suffix[direction]}"
    else:
        tail = op_mode_to_wrapper[ri.op_mode]

    return f"{ri.family.c_name}_{ri.type_name}{tail}"
1520
1521
def type_name(ri, direction, deref=False):
    """Return the C struct type ("struct <op_prefix>") for the direction."""
    base = op_prefix(ri, direction, deref=deref)
    return "struct " + base
1524
1525
def print_prototype(ri, direction, terminate=True, doc=None):
    """
    Print the prototype of the function performing ri.op in ri.op_mode.

    The function takes the socket plus, if the operation has a request,
    a pointer to the request struct. It returns a pointer to the reply
    struct if the operation has a reply, int otherwise. terminate adds
    the trailing ';'.
    """
    mode_spec = ri.op[ri.op_mode]

    fname = ri.op.render_name
    if ri.op_mode == 'dump':
        fname += '_dump'

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        # Argument is named after the direction suffix, minus the underscore
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{arg_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1542
1543
def print_req_prototype(ri):
    """Print the request function prototype, documented with the op's doc."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1546
1547
def print_dump_prototype(ri):
    """Print the request function prototype without a doc comment."""
    direction = "request"
    print_prototype(ri, direction)
1550
1551
def put_typol_fwd(cw, struct):
    """Print the extern declaration of the struct's policy nest."""
    nest_name = f'{struct.render_name}_nest'
    cw.p(f'extern const struct ynl_policy_nest {nest_name};')
1554
1555
def put_typol(cw, struct):
    """
    Print the policy table of the struct (one entry per member, rendered
    by the member's attr_typol()) followed by the policy nest object
    pointing at that table.
    """
    name = struct.render_name
    type_max = struct.attr_set.max_name

    # The attribute policy table
    cw.block_start(line=f'const struct ynl_policy_attr {name}_policy[{type_max} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    # The nest tying the table to its size
    cw.block_start(line=f'const struct ynl_policy_nest {name}_nest =')
    for init in (f'.max_attr = {type_max},', f'.table = {name}_policy,'):
        cw.p(init)
    cw.block_end(line=';')
    cw.nl()
1571
1572
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """
    Print the "<render_name>_str" function which looks the argument up in
    the string table map_name, returning NULL when out of range.

    With an enum the argument has the enum's user type, and for 'flags'
    enums the value is first converted to a bit index with ffs().
    """
    arg_type = enum.user_type if enum else 'int'
    args = [arg_type + ' ' + arg_name]

    cw.write_func_prot('const char *', f'{render_name}_str', args)
    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    body = (
        f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))',
        'return NULL;',
        f'return {map_name}[{arg_name}];',
    )
    for line in body:
        cw.p(line)
    cw.block_end()
    cw.nl()
1586
1587
def put_op_name_fwd(family, cw):
    """Print the declaration of the family's "<family>_op_str" function."""
    func_name = f'{family.c_name}_op_str'
    cw.write_func_prot('const char *', func_name, ['int op'], suffix=';')
1590
1591
def put_op_name(family, cw):
    """
    Print the table mapping message values to operation names and the
    "<family>_op_str" lookup function.

    Only messages with a response value get an entry. The entry is indexed
    by the operation's enum name when request and response values match,
    by the raw response value otherwise.
    """
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue

        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue

        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
1611
1612
def put_enum_to_str_fwd(family, cw, enum):
    """Print the declaration of the enum's "<render_name>_str" function."""
    func_name = f'{enum.render_name}_str'
    value_arg = enum.user_type + ' value'
    cw.write_func_prot('const char *', func_name, [value_arg], suffix=';')
1616
1617
def put_enum_to_str(family, cw, enum):
    """
    Print the table mapping the enum's values to entry names and the
    "<render_name>_str" lookup function.
    """
    map_name = f'{enum.render_name}_strmap'

    cw.block_start(line=f"static const char * const {map_name}[] =")
    rows = [f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values()]
    for row in rows:
        cw.p(row)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1627
1628
def put_req_nested_prototype(ri, struct, suffix=';'):
    """
    Print the prototype of the "<render_name>_put" function which takes
    the message, the attribute type to use and the object to put.
    """
    func_name = f'{struct.render_name}_put'
    func_args = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]

    ri.cw.write_func_prot('int', func_name, func_args, suffix=suffix)
1636
1637
def put_req_nested(ri, struct):
    """
    Print the "<render_name>_put" function: it starts a nest of the given
    attribute type, puts every member of the struct (rendered by the
    member's attr_put()) and ends the nest.
    """
    cw = ri.cw

    put_req_nested_prototype(ri, struct, suffix='')
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = ynl_attr_nest_start(nlh, attr_type);")
    for _, member in struct.member_list():
        member.attr_put(ri, "obj")
    cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1654
1655
1656def _multi_parse(ri, struct, init_lines, local_vars):
1657    if struct.nested:
1658        iter_line = "ynl_attr_for_each_nested(attr, nested)"
1659    else:
1660        if ri.fixed_hdr:
1661            local_vars += ['void *hdr;']
1662        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
1663
1664    array_nests = set()
1665    multi_attrs = set()
1666    needs_parg = False
1667    for arg, aspec in struct.member_list():
1668        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
1669            if aspec["sub-type"] == 'nest':
1670                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
1671                array_nests.add(arg)
1672            else:
1673                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
1674        if 'multi-attr' in aspec:
1675            multi_attrs.add(arg)
1676        needs_parg |= 'nested-attributes' in aspec
1677    if array_nests or multi_attrs:
1678        local_vars.append('int i;')
1679    if needs_parg:
1680        local_vars.append('struct ynl_parse_arg parg;')
1681        init_lines.append('parg.ys = yarg->ys;')
1682
1683    all_multi = array_nests | multi_attrs
1684
1685    for anest in sorted(all_multi):
1686        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")
1687
1688    ri.cw.block_start()
1689    ri.cw.write_func_lvar(local_vars)
1690
1691    for line in init_lines:
1692        ri.cw.p(line)
1693    ri.cw.nl()
1694
1695    for arg in struct.inherited:
1696        ri.cw.p(f'dst->{arg} = {arg};')
1697
1698    if ri.fixed_hdr:
1699        ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
1700        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
1701    for anest in sorted(all_multi):
1702        aspec = struct[anest]
1703        ri.cw.p(f"if (dst->{aspec.c_name})")
1704        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')
1705
1706    ri.cw.nl()
1707    ri.cw.block_start(line=iter_line)
1708    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
1709    ri.cw.nl()
1710
1711    first = True
1712    for _, arg in struct.member_list():
1713        good = arg.attr_get(ri, 'dst', first=first)
1714        # First may be 'unused' or 'pad', ignore those
1715        first &= not good
1716
1717    ri.cw.block_end()
1718    ri.cw.nl()
1719
1720    for anest in sorted(array_nests):
1721        aspec = struct[anest]
1722
1723        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
1724        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1725        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
1726        ri.cw.p('i = 0;')
1727        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1728        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
1729        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1730        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
1731        ri.cw.p('return YNL_PARSE_CB_ERROR;')
1732        ri.cw.p('i++;')
1733        ri.cw.block_end()
1734        ri.cw.block_end()
1735    ri.cw.nl()
1736
1737    for anest in sorted(multi_attrs):
1738        aspec = struct[anest]
1739        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
1740        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1741        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
1742        ri.cw.p('i = 0;')
1743        if 'nested-attributes' in aspec:
1744            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1745        ri.cw.block_start(line=iter_line)
1746        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
1747        if 'nested-attributes' in aspec:
1748            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1749            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
1750            ri.cw.p('return YNL_PARSE_CB_ERROR;')
1751        elif aspec.type in scalars:
1752            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
1753        else:
1754            raise Exception('Nest parsing type not supported yet')
1755        ri.cw.p('i++;')
1756        ri.cw.block_end()
1757        ri.cw.block_end()
1758        ri.cw.block_end()
1759    ri.cw.nl()
1760
1761    if struct.nested:
1762        ri.cw.p('return 0;')
1763    else:
1764        ri.cw.p('return YNL_PARSE_CB_OK;')
1765    ri.cw.block_end()
1766    ri.cw.nl()
1767
1768
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """
    Write the prototype of the '<struct.render_name>_parse' function.

    The parser takes the parse argument and the nest attribute, followed by
    one __u32 parameter for every entry of struct.inherited. @suffix is
    passed through to the code writer.
    """
    params = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    params.extend('__u32 ' + inherited for inherited in struct.inherited)

    parser_name = f'{struct.render_name}_parse'
    ri.cw.write_func_prot('int', parser_name, params, suffix=suffix)
1777
1778
def parse_rsp_nested(ri, struct):
    """
    Write the parser function for a nested struct.

    Structs with members are handed over to _multi_parse(); a struct
    without members gets a body which only returns 0.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    decls = ['const struct nlattr *attr;',
             f'{struct.ptr_name}dst = yarg->data;']

    if not struct.member_list():
        # Empty nest
        cw = ri.cw
        cw.block_start()
        cw.p('return 0;')
        cw.block_end()
        cw.nl()
        return

    _multi_parse(ri, struct, [], decls)
1794
1795
def parse_rsp_msg(ri, deref=False):
    """
    Write the parse callback for the reply (or event) message of an op.

    Nothing is written unless the current op mode carries a 'reply' or the
    mode is 'event'. A reply without members gets a body which only returns
    YNL_PARSE_CB_OK, everything else goes through _multi_parse().
    """
    if not ('reply' in ri.op[ri.op_mode] or ri.op_mode == 'event'):
        return

    proto_args = ['const struct nlmsghdr *nlh',
                  'struct ynl_parse_arg *yarg']
    decls = [f'{type_name(ri, "reply", deref=deref)} *dst;',
             'const struct nlattr *attr;']
    setup = ['dst = yarg->data;']

    parser_name = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    ri.cw.write_func_prot('int', parser_name, proto_args)

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply
        cw = ri.cw
        cw.block_start()
        cw.p('return YNL_PARSE_CB_OK;')
        cw.block_end()
        cw.nl()
        return

    _multi_parse(ri, reply, setup, decls)
1817
1818
def print_req(ri):
    """
    Write the C function which builds and executes a request for ri.op.

    When the op mode has a 'reply' the generated function allocates the
    response, returns it on success and frees it / returns NULL on error.
    Without a reply it returns 0 on success and -1 on error.
    """
    cw = ri.cw
    direction = "request"
    has_reply = 'reply' in ri.op[ri.op_mode]
    ret_ok = 'rsp' if has_reply else '0'

    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']
    if has_reply:
        local_vars.append(f'{type_name(ri, rdir(direction))} *rsp;')
    if ri.fixed_hdr:
        local_vars.extend(['size_t hdr_len;', 'void *hdr;'])

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(local_vars)

    cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()

    if ri.fixed_hdr:
        # Copy the user-provided fixed header right after the netlink header
        for stmt in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            cw.p(stmt)
        cw.nl()

    for _name, member in ri.struct["request"].member_list():
        member.attr_put(ri, "req")
    cw.nl()

    if has_reply:
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yrs.yarg.data = rsp;')
        cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        # Ops without a value of their own use the separate response value
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        cw.nl()
    cw.p("err = ynl_exec(ys, nlh, &yrs);")
    cw.p('if (err < 0)')
    cw.p('goto err_free;' if has_reply else 'return -1;')
    cw.nl()

    cw.p(f"return {ret_ok};")
    cw.nl()

    if has_reply:
        cw.p('err_free:')
        cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        cw.p("return NULL;")

    cw.block_end()
1883
1884
def print_dump(ri):
    """
    Write the C function which executes a dump request for ri.op.

    The generated function returns yds.first on success; on error it frees
    the list and returns NULL.
    """
    cw = ri.cw
    direction = "request"

    print_prototype(ri, direction, terminate=False)
    cw.block_start()

    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']
    if ri.fixed_hdr:
        local_vars.extend(['size_t hdr_len;', 'void *hdr;'])
    cw.write_func_lvar(local_vars)

    cw.p('yds.yarg.ys = ys;')
    cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.p("yds.yarg.data = NULL;")
    cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    # Ops without a value of their own use the separate response value
    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    cw.p(f'yds.rsp_cmd = {rsp_cmd};')
    cw.nl()
    cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if ri.fixed_hdr:
        for stmt in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            cw.p(stmt)
        cw.nl()

    if "request" in ri.op[ri.op_mode]:
        cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        cw.nl()
        for _name, member in ri.struct["request"].member_list():
            member.attr_put(ri, "req")
    cw.nl()

    cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    cw.p('if (err < 0)')
    cw.p('goto free_list;')
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
1935
1936
def call_free(ri, direction, var):
    """Return the C statement calling the '<op prefix>_free' function on @var."""
    prefix = op_prefix(ri, direction)
    return f"{prefix}_free({var});"
1939
1940
def free_arg_name(direction):
    """
    Return the name of the argument of a free function.

    It is 'obj' when no direction is given, otherwise the direction's
    suffix from direction_to_suffix without its first character.
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
1945
1946
def print_alloc_wrapper(ri, direction):
    """Write a static inline '<op prefix>_alloc' helper which calloc()s one struct."""
    struct_name = op_prefix(ri, direction)
    cw = ri.cw

    cw.write_func_prot(f'static inline struct {struct_name} *',
                       f"{struct_name}_alloc", ["void"])
    cw.block_start()
    cw.p(f'return calloc(1, sizeof(struct {struct_name}));')
    cw.block_end()
1953
1954
def print_free_prototype(ri, direction, suffix=';'):
    """
    Write the prototype of the '<op prefix>_free' function.

    The struct type of the argument gets a trailing underscore when
    ri.type_name_conflict is set.
    """
    func = op_prefix(ri, direction)
    struct_name = func + '_' if ri.type_name_conflict else func
    param = free_arg_name(direction)

    ri.cw.write_func_prot('void', f"{func}_free",
                          [f"struct {struct_name} *{param}"], suffix=suffix)
1962
1963
def _print_type(ri, direction, struct):
    """
    Write the C struct definition for @struct.

    The struct contains, in order: the fixed header (if any), a '_present'
    sub-struct holding the presence members reported by the attributes, one
    __u32 per inherited value and finally the members themselves.
    """
    cw = ri.cw

    name_tail = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        name_tail += '_'
    if ri.op_mode == 'dump':
        name_tail += '_dump'

    cw.block_start(line=f"struct {ri.family.c_name}{name_tail}")

    if ri.fixed_hdr:
        cw.p(ri.fixed_hdr + ' _hdr;')
        cw.nl()

    # The presence struct is only opened once the first member shows up
    present_open = False
    for _name, member in struct.member_list():
        for kind in ('len', 'bit'):
            presence = member.presence_member(ri.ku_space, kind)
            if not presence:
                continue
            if not present_open:
                cw.block_start(line="struct")
                present_open = True
            cw.p(presence)
    if present_open:
        cw.block_end(line='_present;')
        cw.nl()

    for inherited in struct.inherited:
        cw.p(f"__u32 {inherited};")

    for _name, member in struct.member_list():
        member.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
1999
2000
def print_type(ri, direction):
    """Write the struct definition of the op's struct for @direction."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
2003
2004
def print_type_full(ri, struct):
    """Write the struct definition of @struct without a direction."""
    no_direction = ""
    _print_type(ri, no_direction, struct)
2007
2008
def print_type_helpers(ri, direction, deref=False):
    """
    Write the free prototype and, for user space requests, the setters.

    Setters are only written when ri.ku_space is 'user' and @direction is
    'request'.
    """
    print_free_prototype(ri, direction)
    ri.cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _name, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
2017
2018
def print_req_type_helpers(ri):
    """Write the alloc wrapper and type helpers of a request, if it has attributes."""
    request = ri.struct["request"]
    if len(request.attr_list) == 0:
        return

    for emit in (print_alloc_wrapper, print_type_helpers):
        emit(ri, "request")
2024
2025
def print_rsp_type_helpers(ri):
    """Write the type helpers of the reply, if the current op mode has one."""
    has_reply = 'reply' in ri.op[ri.op_mode]
    if has_reply:
        print_type_helpers(ri, "reply")
2030
2031
def print_parse_prototype(ri, direction, terminate=True):
    """
    Write the prototype of the '<op>_rsp_parse' / '<op>_req_parse' function.

    The '_rsp' flavour is used when @direction is 'reply', '_req' otherwise.
    @terminate selects whether the prototype ends with a semicolon.
    """
    tag = "_req"
    if direction == "reply":
        tag = "_rsp"

    proto_args = ['const struct nlattr **tb',
                  f"struct {ri.op.render_name}{tag} *req"]
    ri.cw.write_func_prot('void', f"{ri.op.render_name}{tag}_parse",
                          proto_args, suffix=';' if terminate else '')
2040
2041
def print_req_type(ri):
    """Write the request struct definition, unless the request has no attributes."""
    request = ri.struct["request"]
    if len(request.attr_list) == 0:
        return

    print_type(ri, "request")
2046
2047
def print_req_free(ri):
    """Write the free function of the request, if the current op mode has one."""
    has_request = 'request' in ri.op[ri.op_mode]
    if has_request:
        _free_type(ri, 'request', ri.struct['request'])
2052
2053
def print_rsp_type(ri):
    """
    Write the reply struct definition.

    This happens for 'do' and 'dump' modes which have a 'reply', and for
    the 'event' mode; nothing is written in any other case.
    """
    mode = ri.op_mode
    if mode in ('do', 'dump'):
        has_reply = 'reply' in ri.op[mode]
    else:
        has_reply = mode == 'event'

    if not has_reply:
        return
    print_type(ri, 'reply')
2062
2063
def print_wrapped_type(ri):
    """
    Write the struct which wraps the reply object, plus its free prototype.

    For dumps the wrapper carries a 'next' pointer of its own type, for
    'notify' and 'event' modes it carries family, cmd, next and free
    members; the reply itself is always embedded as the 'obj' member.
    """
    cw = ri.cw
    wrapper = type_name(ri, 'reply')

    cw.block_start(line=f"{wrapper}")
    if ri.op_mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        cw.p('__u16 family;')
        cw.p('__u8 cmd;')
        cw.p('struct ynl_ntf_base_type *next;')
        cw.p(f"void (*free)({wrapper} *ntf);")
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
2078
2079
def _free_type_members_iter(ri, struct):
    """Declare the 'i' iterator if any member of @struct needs one to be freed."""
    needs_iter = any(member.free_needs_iter()
                     for _name, member in struct.member_list())
    if needs_iter:
        ri.cw.p('unsigned int i;')
        ri.cw.nl()
2086
2087
def _free_type_members(ri, var, struct, ref=''):
    """Ask every member of @struct to write its free code for @var / @ref."""
    members = struct.member_list()
    for _name, member in members:
        member.free(ri, var, ref)
2091
2092
def _free_type(ri, direction, struct):
    """
    Write the free function for @struct.

    The object itself is only free()d when a direction is given.
    """
    cw = ri.cw
    obj_name = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, obj_name, struct)
    if direction:
        cw.p(f'free({obj_name});')
    cw.block_end()
    cw.nl()
2104
2105
def free_rsp_nested_prototype(ri):
    """Write the prototype of the free function of a nested type."""
    no_direction = ""
    print_free_prototype(ri, no_direction)
2108
2109
def free_rsp_nested(ri, struct):
    """Write the free function of a nested type."""
    no_direction = ""
    _free_type(ri, no_direction, struct)
2112
2113
def print_rsp_free(ri):
    """Write the free function of the reply, if the current op mode has one."""
    has_reply = 'reply' in ri.op[ri.op_mode]
    if has_reply:
        _free_type(ri, 'reply', ri.struct['reply'])
2118
2119
def print_dump_type_free(ri):
    """
    Write the free function for the list returned by a dump.

    The generated code walks the list until YNL_LIST_END, freeing the
    members of each element's 'obj' and then the element itself.
    """
    cw = ri.cw
    reply = ri.struct['reply']
    list_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{list_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply)
    cw.p('rsp = next;')
    cw.p('next = rsp->next;')
    cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
2138
2139
def print_ntf_type_free(ri):
    """Write the free function of a notification: free the 'obj' members, then rsp."""
    cw = ri.cw
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
2148
2149
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """
    Write the first line of a request policy array.

    With @terminate this is an 'extern' forward declaration, which is
    skipped entirely for op policies (@ri given) that should be static.
    Without @terminate it is the opening line of the definition. The array
    is named after the op when @ri is given, after @struct otherwise.
    """
    is_static = ri and policy_should_be_static(struct.family)

    if terminate:
        if is_static:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if is_static else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2172
2173
def print_req_policy(cw, struct, ri=None):
    """
    Write the definition of a request policy array.

    For op policies the array is wrapped in the ifdef block of the op's
    'config-cond', if any.
    """
    if ri and ri.op:
        cw.ifdef_block(ri.op.get('config-cond', None))

    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _name, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")

    cw.ifdef_block(None)
    cw.nl()
2183
2184
def kernel_can_gen_family_struct(family):
    """Return True if the family's protocol is 'genetlink'."""
    proto = family.proto
    return proto == 'genetlink'
2187
2188
def policy_should_be_static(family):
    """
    Tell whether policies of @family should be static.

    That is the case for the 'split' kernel policy and for families for
    which the family struct can be generated.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2191
2192
def print_kernel_policy_ranges(family, cw):
    """
    Write the netlink_range_validation structs used by the policies.

    One struct is written for every request attribute with a 'full-range'
    check, attribute sets which are subsets of another set are skipped.
    A heading comment precedes the first struct.
    """
    heading_done = False
    for _set_name, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _attr_name, attr in attr_set.items():
            if not (attr.request and 'full-range' in attr.checks):
                continue

            if not heading_done:
                cw.p('/* Integer value ranges */')
                heading_done = True

            # Types starting with 'u' use the unsigned flavour of the struct
            if attr.type[0] == 'u':
                sign, suffix = '', 'ULL'
            else:
                sign, suffix = '_signed', 'LL'

            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            members = [(bound, attr.get_limit_str(bound, suffix=suffix))
                       for bound in ('min', 'max') if bound in attr.checks]
            cw.write_struct_init(members)
            cw.block_end(line=';')
            cw.nl()
2220
2221
def print_kernel_op_table_fwd(family, cw, terminate):
    """
    Write the first line of the kernel ops table and, for headers, prototypes.

    Without @terminate only the opening of the table definition is written.
    With @terminate an 'extern' declaration of the table is written (only
    when the family struct can't be generated, so the table is exported),
    followed by the prototypes of the pre/post hooks and of the doit /
    dumpit handlers of all non-async ops.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.ident_name} */")

        struct_type = {'global': 'genl_small_ops',
                       'per-op': 'genl_ops',
                       'split': 'genl_split_ops'}[family.kernel_policy]

        # Only exported tables carry an explicit element count
        if not exported:
            qual, cnt = 'static const', ""
        else:
            qual = 'const'
            if family.kernel_policy == 'split':
                # Split ops have a separate entry for do and for dump
                cnt = sum(('do' in op) + ('dump' in op)
                          for op in family.ops.values())
            else:
                cnt = len(family.ops)

        decl = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {decl};")
        else:
            cw.block_start(line=decl + ' =')

    if not terminate:
        return

    cw.nl()
    do_args = ['const struct genl_split_ops *ops',
               'struct sk_buff *skb', 'struct genl_info *info']
    dump_args = ['struct netlink_callback *cb']
    hook_protos = (('int', 'pre', 'do', do_args),
                   ('void', 'post', 'do', do_args),
                   ('int', 'pre', 'dump', dump_args),
                   ('int', 'post', 'dump', dump_args))
    for ret_type, stage, mode, proto_args in hook_protos:
        for hook in family.hooks[stage][mode]['list']:
            cw.write_func_prot(ret_type, c_lower(hook), list(proto_args),
                               suffix=';')

    cw.nl()

    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            handler = c_lower(f"{family.ident_name}-nl-{op_name}-doit")
            cw.write_func_prot('int', handler,
                               ['struct sk_buff *skb', 'struct genl_info *info'],
                               suffix=';')

        if 'dump' in op:
            handler = c_lower(f"{family.ident_name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', handler,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'],
                               suffix=';')
    cw.nl()
2287
2288
def print_kernel_op_table_hdr(family, cw):
    """Write the header part of the kernel ops table (declaration and prototypes)."""
    for_header = True
    print_kernel_op_table_fwd(family, cw, terminate=for_header)
2291
2292
def print_kernel_op_table(family, cw):
    """
    Write the definition of the kernel ops table of @family.

    For the 'global' and 'per-op' kernel policies there is one entry per
    non-async op; for 'split' there is one entry per op mode ('do', 'dump')
    of each non-async op. Entries are wrapped in the ifdef block of the
    op's 'config-cond', if any.
    """
    def _or_names(prefix, names):
        # C expression OR-ing together the prefixed, upper-cased names
        return ' | '.join([c_upper(prefix + x) for x in names])

    print_kernel_op_table_fwd(family, cw, terminate=False)

    policy = family.kernel_policy
    if policy in ('global', 'per-op'):
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                _or_names('genl-dont-validate-',
                                          op['dont-validate'])))
            for op_mode in ('do', 'dump'):
                if op_mode not in op:
                    continue
                handler = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                members.append((op_mode + 'it', handler))
            if policy == 'per-op':
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                pol_name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', pol_name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', _or_names('genl-', op['flags'])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif policy == 'split':
        cb_names = {'do': {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}
        # dont-validate flags which do not apply to the given op mode
        not_applicable = {'do': ('dump', 'dump-strict'),
                          'dump': ('strict',)}

        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            for op_mode in ('do', 'dump'):
                if op_mode not in op:
                    continue
                mode_spec = op[op_mode]

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    dont_validate = [x for x in op['dont-validate']
                                     if x not in not_applicable[op_mode]]
                    if dont_validate:
                        members.append(('validate',
                                        _or_names('genl-dont-validate-',
                                                  dont_validate)))
                handler = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in mode_spec:
                    members.append((cb_names[op_mode]['pre'],
                                    c_lower(mode_spec['pre'])))
                members.append((op_mode + 'it', handler))
                if 'post' in mode_spec:
                    members.append((cb_names[op_mode]['post'],
                                    c_lower(mode_spec['post'])))
                if 'request' in mode_spec:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=mode_spec['request']['attributes'])

                    if op.dual_policy:
                        pol_name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        pol_name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', pol_name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                cap = ['cmd-cap-' + op_mode]
                flags = op['flags'] + cap if 'flags' in op else [] + cap
                members.append(('flags', _or_names('genl-', flags)))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2371
2372
def print_kernel_mcgrp_hdr(family, cw):
    """Write the enum of multicast group IDs, if the family has any groups."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{grp['name']},"))
    cw.block_end(';')
    cw.nl()
2383
2384
def print_kernel_mcgrp_src(family, cw):
    """Write the genl_multicast_group array, if the family has any groups."""
    groups = family.mcgrps['list']
    if not groups:
        return

    array_decl = 'static const struct genl_multicast_group ' + family.c_name + '_nl_mcgrps[] ='
    cw.block_start(array_decl)
    for grp in groups:
        grp_name = grp['name']
        # Entries are indexed by the group's enum ID
        index = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p('[' + index + '] = { "' + grp_name + '", },')
    cw.block_end(';')
    cw.nl()
2396
2397
def print_kernel_family_struct_hdr(family, cw):
    """
    Write the declarations which go with the generated family struct.

    That is the 'extern' declaration of the struct itself and, when the
    family has a 'sock-priv' type, the prototypes of its init and destroy
    functions. Nothing is written if the family struct can't be generated.
    """
    if not kernel_can_gen_family_struct(family):
        return

    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
    cw.nl()

    if 'sock-priv' not in family.kernel_family:
        return
    for action in ('init', 'destroy'):
        cw.p(f'void {family.c_name}_nl_sock_priv_{action}({family.kernel_family["sock-priv"]} *priv);')
    cw.nl()
2408
2409
def print_kernel_family_struct_src(family, cw):
    """
    Write the definition of the family's struct genl_family.

    Nothing is written if the family struct can't be generated. When the
    family has a 'sock-priv' type, static trampolines taking a void pointer
    are written first and used as the sock_priv callbacks.

    The struct is named '<family.c_name>_nl_family', the same name the
    'extern' declaration written by print_kernel_family_struct_hdr() uses.
    """
    if not kernel_can_gen_family_struct(family):
        return

    has_sock_priv = 'sock-priv' in family.kernel_family

    if has_sock_priv:
        # Generate "trampolines" to make CFI happy
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
                      ["void *priv"])
        cw.nl()
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
                      ["void *priv"])
        cw.nl()

    # Use c_name (not ident_name) so the definition always matches the
    # declaration in the header and the other identifiers below
    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    if has_sock_priv:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
2445
2446
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """
    Open a uAPI enum block.

    The enum is named after obj[enum_name] when that key is present (it
    stays anonymous if the value is empty), else after the family's c_name
    and obj[ckey] when @ckey is given and present. Otherwise the enum is
    anonymous.
    """
    tag = None
    if enum_name in obj:
        explicit = obj[enum_name]
        if explicit:
            tag = c_lower(explicit)
    elif ckey and ckey in obj:
        tag = family.c_name + '_' + c_lower(obj[ckey])

    cw.block_start(line='enum' if tag is None else 'enum ' + tag)
2455
2456
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """
    Write the uAPI enum of the family's messages as one unified ID space.

    With @separate_ntf messages which are a 'notify' or an 'event' are left
    out. An explicit value is written whenever a message's value differs
    from the one the C compiler would assign. The max is written either as
    an enum entry or, with @max_by_define, as a define after the enum.
    """
    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    implicit = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        entry = op.enum_name
        if op.value != implicit:
            entry += f" = {op.value},"
            implicit = op.value
        else:
            entry += ','
        cw.p(entry)
        implicit += 1
    cw.nl()

    if max_by_define:
        cw.p(cnt_name + '')
        cw.block_end(line=';')
        cw.p(f"#define {max_name} {max_value}")
    else:
        cw.p(cnt_name + ',')
        cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
    cw.nl()
2482
2483
def render_uapi_directional(family, cw, max_by_define):
    """
    Write the uAPI enums of the family's messages, one per direction.

    The first enum lists the messages user space sends (ops with a 'do'
    which are not events). The second lists what the kernel sends: replies
    of 'do' ops (with a _REPLY suffix), notifications and events. The max
    of each enum is written either as an enum entry or, with
    @max_by_define, as a define after the enum.
    """
    def _numbered(entries):
        # Print (op, label) pairs, with explicit values where needed
        val = 0
        for op, label in entries:
            suffix = ','
            if op.value and op.value != val:
                suffix = f" = {op.value},"
                val = op.value
            cw.p(label + suffix)
            val += 1

    def _close(max_name, cnt_name):
        # Print the count / max entries and close the enum
        max_value = f"({cnt_name} - 1)"
        cw.nl()
        cw.p(cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {max_name} {max_value}")
        cw.nl()

    def _kernel_entries():
        for op in family.msgs.values():
            is_ntf = 'notify' in op or 'event' in op
            if not (is_ntf or ('do' in op and 'reply' in op['do'])):
                continue
            label = op.enum_name
            if not is_ntf:
                label = f'{label}_REPLY'
            yield op, label

    max_name = f"{family.op_prefix}USER_MAX"
    cnt_name = f"__{family.op_prefix}USER_CNT"

    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_USER_NONE = 0,'))
    _numbered((op, op.enum_name) for op in family.msgs.values()
              if 'do' in op and 'event' not in op)
    _close(max_name, cnt_name)

    max_name = f"{family.op_prefix}KERNEL_MAX"
    cnt_name = f"__{family.op_prefix}KERNEL_CNT"

    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_KERNEL_NONE = 0,'))
    _numbered(_kernel_entries())
    _close(max_name, cnt_name)
2536
2537
def render_uapi(family, cw):
    """
    Write the whole uAPI C header for @family through the code writer @cw.

    Output order: include guard, family name/version defines, the spec
    'definitions' (const defines plus enum/flags with optional kdoc),
    one enum per attribute set, the command enum(s), an optional separate
    enum for notifications, multicast group defines and the closing #endif.
    """
    guard = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H".replace('/', '_')
    cw.p('#ifndef ' + guard)
    cw.p('#define ' + guard)
    cw.nl()

    # Family name and version form their own group of defines
    cw.writes_defines([(family.fam_key, family["name"]),
                       (family.ver_key, family.get('version', 1))])
    cw.nl()

    # Consecutive 'const' definitions are collected and written as one group;
    # any other definition type flushes what has been collected so far.
    pending = []
    for const in family['definitions']:
        kind = const['type']
        if kind != 'const':
            cw.writes_defines(pending)
            pending = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if kind in ('enum', 'flags'):
            enum = family.consts[const['name']]

            # Definitions carrying a 'header' are not rendered here
            # (presumably because that header already provides them)
            if enum.header:
                continue

            if enum.has_doc():
                if enum.has_entry_doc():
                    cw.p('/**')
                    summary = ' - ' + enum['doc'] if 'doc' in enum else ''
                    cw.write_doc_line(enum.enum_name + summary)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        cw.write_doc_line('@' + entry.c_name + ': ' + entry['doc'])
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                # Only spell out the value where the entry asks for it
                if entry.value_change:
                    cw.p(entry.c_name + f" = {entry.user_value()},")
                else:
                    cw.p(entry.c_name + ',')

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if kind == 'flags':
                    cw.p(c_upper(name_pfx + 'mask') + f' = {enum.get_mask()},')
                else:
                    # Fall back to a generated counter name if the spec has none
                    cnt_name = c_upper(enum.enum_cnt_name or '__' + name_pfx + 'max')
                    cw.p(cnt_name + ',')
                    cw.p(c_upper(name_pfx + 'max') + ' = (' + cnt_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif kind == 'const':
            pending.append([c_upper(family.get('c-define-name',
                                               f"{family.ident_name}-{const['name']}")),
                            const['value']])

    if pending:
        cw.writes_defines(pending)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    # Attribute sets, subsets do not get an enum of their own
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        # Track the value C would assign implicitly, so that explicit
        # initializers are only written where the spec value differs.
        implicit = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            if attr.value == implicit:
                cw.p(attr.enum_name + ',')
            else:
                cw.p(attr.enum_name + f" = {attr.value},")
                implicit = attr.value
            implicit += 1
        if attr_set.items():
            cw.nl()
        # The max is either the last enum entry or a #define after the enum
        if max_by_define:
            cw.p(attr_set.cnt_name + '')
            cw.block_end(line=';')
            cw.p(f"#define {attr_set.max_name} {max_value}")
        else:
            cw.p(attr_set.cnt_name + ',')
            cw.p(f"{attr_set.max_name} = {max_value}")
            cw.block_end(line=';')
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    model = family.msg_id_model
    if model == 'unified':
        render_uapi_unified(family, cw, max_by_define, separate_ntf)
    elif model == 'directional':
        render_uapi_directional(family, cw, max_by_define)
    else:
        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')

    if separate_ntf:
        # Notifications and events get an enum of their own
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if 'notify' not in op and 'event' not in op:
                continue

            if 'value' in op:
                cw.p(op.enum_name + f" = {op['value']},")
            else:
                cw.p(op.enum_name + ',')
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    mcgrp_defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        define = c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}"))
        mcgrp_defines.append([define, f'{name}'])
    cw.nl()
    if mcgrp_defines:
        cw.writes_defines(mcgrp_defines)
        cw.nl()

    cw.p(f'#endif /* {guard} */')
2675
2676
def _render_user_ntf_entry(ri, op):
    """
    Write one designated initializer of the notification info table,
    indexed by the enum name of @op and filled in from the render info @ri.
    """
    out = ri.cw

    out.block_start(line=f"[{op.enum_name}] = ")

    event_struct = type_name(ri, 'event')
    out.p(f".alloc_sz\t= sizeof({event_struct}),")

    parse_pfx = op_prefix(ri, 'reply', deref=True)
    out.p(f".cb\t\t= {parse_pfx}_parse,")

    policy = ri.struct['reply'].render_name
    out.p(f".policy\t\t= &{policy}_nest,")

    free_pfx = op_prefix(ri, 'notify')
    out.p(f".free\t\t= (void *){free_pfx}_free,")

    out.block_end(line=',')
2684
2685
def render_user_family(family, cw, prototype):
    """
    Write the user space 'struct ynl_family' object for @family.

    With @prototype set only the extern declaration is written.
    Otherwise the notification info table (if the family has any
    notifications) and the family definition itself are written.

    All C identifiers are built from family.c_name. The spec name
    (family['name']) may contain dashes, which are not valid in a
    C identifier, so it must not be used for the table name.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    ntf_table = f'{family.c_name}_ntf_info'

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {ntf_table}[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                # A notification reuses the types of the operation it refers to
                op = family.ops[ntf_op['notify']]
                ri = RenderInfo(cw, family, "user", op, "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        # Regular operations which also carry an 'event' get an entry, too
        for op_name, op in family.ops.items():
            if 'event' not in op:
                continue
            ri = RenderInfo(cw, family, "user", op, "event")
            _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {ntf_table},")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({ntf_table}),")
    cw.block_end(line=';')
2721
2722
def family_contains_bitfield32(family):
    """
    Return True if any attribute of any attribute set of @family
    has the type bitfield32. Subsets are not looked at.
    """
    return any(attr.type == "bitfield32"
               for _, attr_set in family.attr_sets.items()
               if not attr_set.subset_of
               for _, attr in attr_set.items())
2731
2732
def find_kernel_root(full_path):
    """
    Find the kernel tree @full_path belongs to.

    Walk up the parent directories of @full_path until one is found
    which contains a MAINTAINERS file.

    Returns a tuple (root, sub_path), where root is the directory holding
    MAINTAINERS and sub_path is @full_path relative to that directory.
    If no parent directory contains MAINTAINERS the walk stops once the
    top of the path is reached and (None, @full_path) is returned,
    rather than looping forever.
    """
    spec_path = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # Joining with '' left a trailing separator on sub_path, drop it
            return parent, sub_path[:-1]
        # dirname() of '/' (or of '') is itself - there is nowhere left to go
        if parent == full_path:
            return None, spec_path
        full_path = parent
2741
2742
def main():
    """
    Command line entry point of the code generator.

    Parse the arguments, load the spec given by --spec and write the
    requested output (--mode uapi / kernel / user, as --header or --source)
    through a CodeWriter, to the file given by -o if any.

    Exits with status 1 if the spec cannot be parsed as YAML or if its
    license is not the expected one. Argument errors exit via argparse.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True,
                        choices=('user', 'kernel', 'uapi'))
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    # --header and --source share a destination, None means neither was given
    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
            print('Spec license:', parsed.license)
            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
            sys.exit(1)
    except yaml.YAMLError as exc:
        print(exc)
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    # File preamble: license, origin of the file and the generator arguments
    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    # The uAPI header is rendered separately, nothing below applies to it
    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    # Name of the header matching the output file, with the last two
    # characters of the output file name (e.g. ".c") replaced by ".h"
    if args.out_file:
        hdr_file = os.path.basename(args.out_file[:-2]) + ".h"
    else:
        hdr_file = "generated_header_file.h"

    # Includes
    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = []
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    if args.mode == 'user':
        headers.append(parsed.uapi_header)
    # Include each header once, keeping the order of first appearance
    seen_header = []
    for one in headers:
        if one not in seen_header:
            cw.p(f"#include <{one}>")
            seen_header.append(one)
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            # Print the section comment only if there is something to put in it
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for op_name, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            print_kernel_policy_ranges(parsed, cw)

            # Print the section comment only if there is something to put in it
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            for op_name, op in parsed.ops.items():
                if parsed.kernel_policy in {'per-op', 'split'}:
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    if not ri.type_consistent:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            has_recursive_nests = False
            cw.p('/* Policies */')
            # Recursive nests are declared up front, before the policies proper
            for struct in parsed.pure_nested_structs.values():
                if struct.recursive:
                    put_typol_fwd(cw, struct)
                    has_recursive_nests = True
            if has_recursive_nests:
                cw.nl()
            for name in parsed.pure_nested_structs:
                struct = Struct(parsed, name)
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            # With recursive nests the prototypes of the helpers go first
            if has_recursive_nests:
                for attr_set, struct in parsed.pure_nested_structs.items():
                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                    free_rsp_nested_prototype(ri)
                    if struct.request:
                        put_req_nested_prototype(ri, struct)
                    if struct.reply:
                        parse_rsp_nested_prototype(ri, struct)
                cw.nl()
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent:
                        parse_rsp_msg(ri, deref=True)
                    print_req_free(ri)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
3051
3052
# Run the generator only when executed as a script, not when imported
if __name__ == "__main__":
    main()
3055