xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision 29d34a4d785bbf389d57bfdafe2a19dad6ced3a4)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17
18
def c_upper(name):
    """Return *name* as an upper-case C identifier, mapping '-' to '_'."""
    underscored = name.replace('-', '_')
    return underscored.upper()
21
22
def c_lower(name):
    """Return *name* as a lower-case C identifier, mapping '-' to '_'."""
    underscored = name.replace('-', '_')
    return underscored.lower()
25
26
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value.

    The first character selects signedness ('s' is signed), the digits
    between it and the trailing '-min'/'-max' give the bit width.
    """
    kind = name[0]
    is_min = name.endswith('-min')
    if kind == 'u' and is_min:
        return 0

    bits = int(name[1:-4])
    if kind == 's':
        # one bit is taken by the sign
        bits -= 1
    top = (1 << bits) - 1
    if kind == 's' and is_min:
        return -top - 1
    return top
40
41
class BaseNlLib:
    """Library-specific snippets of C used by the generated code."""

    def get_family_id(self):
        """Return the C expression the generated code uses for the family id."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
45
46
class Type(SpecAttr):
    """Code-generator wrapper around one attribute of an attribute set.

    Adds the helpers used to emit C for the attribute: struct members,
    kernel policy entries (attr_policy), user-space type-policy entries
    (attr_typol), put/parse code and setter functions.  Subclasses override
    the ``_``-prefixed hooks for their specific attribute type.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        # NOTE(review): when the spec provides 'checks' this is that same dict
        # object, so subclasses which add keys to it (TypeScalar) modify it.
        self.checks = attr.get('checks', {})

        # Flipped by set_request() / set_reply()
        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # Same trailing '_' rule Struct applies to nests which share
            # their name with a family const.
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Names found in _C_KW get a '_' suffix, a leading digit gets a '_'
        # prefix, so the result is usable as a C member name.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attribute as used in requests (also on the real attr of a subset)."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attribute as used in replies (also on the real attr of a subset)."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """Return the numeric value of check *limit*, or None if unset.

        The check may be an int, the name of a family const (its 'value'
        is used) or a string such as 'u32-max' (see limit_to_number()).
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """Return check *limit* as C source text, or '' if unset.

        Ints are rendered as literals followed by *suffix*; const names are
        rendered as upper-case macro names (prefixed with the family name
        unless the const comes from a header); other strings are upper-cased.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute the C enum name of the attribute.

        Raises:
            Exception: if a subset attr overrides the checks of its real attr.
        """
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Return how presence is tracked: 'bit' here; subclasses may use 'len', 'count' or ''."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the ``_present`` struct member declaration, or None.

        A member is only emitted when this attr's presence type equals
        *type_filter*; 'user' space uses the ``__u32`` spelling.
        """
        if self.presence_type() != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if self.presence_type() == 'bit':
            return f"{pfx}u32 {self.c_name}:1;"
        if self.presence_type() == 'len':
            return f"{pfx}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        return None

    def free_needs_iter(self):
        return False

    def free(self, ri, var, ref):
        """Emit a free() of the member if it is heap-allocated (multi-val or 'len')."""
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the C argument declarations used to pass this attr to a setter."""
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for this attr."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit the kernel nla_policy entry; big-endian u16/u32 use NLA_BE16/NLA_BE32."""
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit the user-space type-policy entry for this attr."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        # 'bit' and 'len' presence guard the put; other presence types don't
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the parse branch for this attr; *first* selects 'if' vs 'else if'.

        Returns:
            True (a branch was emitted).
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        # Hooks may return a single line as a plain string
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a ``static inline`` setter for this attr.

        *ref* is the list of enclosing nest member names; the setter marks
        each enclosing nest present before assigning the member itself.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        for i, name in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{name}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == len(ref) - 1 and self.presence_type() != 'bit':
                continue
            code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        # A setter which frees but never allocates gets a '__' prefix
        frees = any('free(' in line for line in code)
        allocs = any('alloc(' in line for line in code)
        if frees and not allocs:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
284
285
class TypeUnused(Type):
    """Attribute slot marked 'unused' in the spec.

    The user-space type policy rejects it and parsing it returns an error;
    no kernel policy, put, parse branch or setter is emitted for it.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        error_lines = ['return YNL_PARSE_CB_ERROR;']
        return error_lines, None, None

    # Everything below intentionally emits nothing.
    def attr_policy(self, cw):
        pass

    def attr_put(self, ri, var):
        pass

    def attr_get(self, ri, var, first):
        pass

    def setter(self, ri, space, direction, deref=False, ref=None):
        pass
309        pass
310
311
class TypePad(Type):
    """Padding attribute.

    The user-space type policy ignores it; no kernel policy, put, parse
    branch or setter is emitted for it.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    # Everything below intentionally emits nothing.
    def attr_policy(self, cw):
        pass

    def attr_put(self, ri, var):
        pass

    def attr_get(self, ri, var, first):
        pass

    def setter(self, ri, space, direction, deref=False, ref=None):
        pass
332        pass
333
334
class TypeScalar(Type):
    """Integer attribute (one of the types in the module's ``scalars`` set).

    Derives min/max range checks from an associated enum, validates the
    limits, and renders the matching NLA_POLICY_* kernel policy.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Default the range checks to the value range of the enum
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if 'min' not in self.checks:
                if low != 0 or self.type[0] == 's':
                    self.checks['min'] = low
            if 'max' not in self.checks:
                self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        min_limit = self.get_limit('min', 0)
        max_limit = self.get_limit('max', 0)
        low = min(min_limit, max_limit)
        high = max(min_limit, max_limit)
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside the s16 range need a separate full-range struct
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Work out whether the value is a bitfield and which C type holds it."""
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _attr_policy(self, policy):
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
            else:
                enum = self.family.consts[self.checks['flags-mask']]
            # Use the real flag bits; (1 << count) - 1 is wrong for flag
            # sets which do not start at bit 0.
            mask = enum.get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
421
422
class TypeFlag(Type):
    """Flag attribute: put as an empty (NULL, 0) payload, nothing to parse or set."""

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        return []

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        no_lines = []
        return no_lines, None, None

    def _setter_lines(self, ri, member, presence):
        return []
437        return []
438
439
class TypeString(Type):
    """String attribute, stored as a heap copy plus ``_present.<name>_len``."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        len_part = ''
        if 'max-len' in self.checks:
            len_part = ', .len = ' + self.get_limit_str('max-len')
        return '{ .type = ' + policy + len_part + ', }'

    def attr_policy(self, cw):
        # Only an explicit 'unterminated-ok' check relaxes NUL termination
        unterminated = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        field = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        get_lines = [f"{len_mem} = len;",
                     f"{field} = malloc(len + 1);",
                     f"memcpy({field}, ynl_attr_get_str(attr), len);",
                     f"{field}[len] = 0;"]
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"free({member});",
                f"{length} = strlen({self.c_name});",
                f"{member} = malloc({length} + 1);",
                f'memcpy({member}, {self.c_name}, {length});',
                f'{member}[{length}] = 0;']
489                f'{member}[{presence}_len] = 0;']
490
491
class TypeBinary(Type):
    """Opaque byte-array attribute, stored as a heap copy plus ``_present.<name>_len``."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        """Render the policy; at most one of exact-len/min-len/max-len is supported."""
        if len(self.checks) == 0:
            return '{ .type = NLA_BINARY, }'
        if len(self.checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')

        check_name = next(iter(self.checks))
        if check_name == 'exact-len':
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        if check_name == 'min-len':
            return '{ .len = ' + self.get_limit_str('min-len') + ', }'
        if check_name == 'max-len':
            return 'NLA_POLICY_MAX_LEN(' + self.get_limit_str('max-len') + ')'
        raise Exception('Unsupported check for binary type: ' + check_name)

    def attr_put(self, ri, var):
        data = f"{var}->{self.c_name}"
        length = f"{var}->_present.{self.c_name}_len"
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, {data}, {length})")

    def _attr_get(self, ri, var):
        field = f"{var}->{self.c_name}"
        get_lines = [f"{var}->_present.{self.c_name}_len = len;",
                     f"{field} = malloc(len);",
                     f"memcpy({field}, ynl_attr_data(attr), len);"]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"free({member});",
                f"{length} = len;",
                f"{member} = malloc({length});",
                f'memcpy({member}, {self.c_name}, {length});']
542                f'memcpy({member}, {self.c_name}, {presence}_len);']
543
544
class TypeBitfield32(Type):
    """``struct nla_bitfield32`` attribute whose valid bits come from an enum."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        """Render NLA_POLICY_BITFIELD32 with the enum's flag mask.

        Raises:
            Exception: if the attribute has no 'enum'.
        """
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        enum = self.family.consts[self.attr['enum']]
        mask = enum.get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        line = f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, sizeof(struct nla_bitfield32))"
        self._attr_put_line(ri, var, line)

    def _attr_get(self, ri, var):
        return f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
567        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
568
569
class TypeNest(Type):
    """Nested attribute set rendered as an embedded struct.

    Recursive nests (outside of an op) are held by pointer instead.
    """

    def is_recursive(self):
        nest = self.family.pure_nested_structs[self.nested_attrs]
        return nest.recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def free(self, ri, var, ref):
        target = f"{var}->{ref}{self.c_name}"
        if self.is_recursive_for_op(ri):
            # pointer member: may be NULL
            ri.cw.p(f'if ({target})')
            ri.cw.p(f'{self.nested_render_name}_free({target});')
        else:
            ri.cw.p(f'{self.nested_render_name}_free(&{target});')

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        at = '' if self.is_recursive_for_op(ri) else '&'
        put_call = (f"{self.nested_render_name}_put(nlh, {self.enum_name}, "
                    f"{at}{var}->{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [f"if ({self.nested_render_name}_parse(&parg, attr))",
                     "return YNL_PARSE_CB_ERROR;"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit setters for every non-recursive member of the nest."""
        ref = (ref or []) + [self.c_name]

        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested.member_list():
            if not member.is_recursive():
                member.setter(ri, self.nested_attrs, direction, deref=deref, ref=ref)
608            attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=ref)
609
610
class TypeMultiAttr(Type):
    """Attribute which may appear multiple times in one message.

    Rendered as an array pointer plus an ``n_<name>`` count.  *base_type*
    is the single-instance Type whose policy/typol entries are reused.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def _is_nest(self):
        # Checked before any attr['type'] lookup so every method treats a
        # missing 'type' as a nest, matching _complex_member_type().
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if self._is_nest():
            return self.nested_struct_type
        elif self.attr['type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        else:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        return self._is_nest()

    def free(self, ri, var, ref):
        """Emit code freeing the array (and each nested element for nests)."""
        if self._is_nest():
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        elif self.attr['type'] in scalars:
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        else:
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only count here; elements are collected after the parse loop
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        """Emit a loop putting every element of the array."""
        if self._is_nest():
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        elif self.attr['type'] in scalars:
            put_type = self.type
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # For multi-attr we have a count, not presence, hack up the presence
        presence = presence[:-(len('_present.') + len(self.c_name))] + "n_" + self.c_name
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
671                f"{presence} = n_{self.c_name};"]
672
673
class TypeArrayNest(Type):
    """'indexed-array' attribute: a nest whose children are array elements.

    Rendered as an array pointer plus an ``n_<name>`` count.
    """

    def _sub_type(self):
        # A missing 'sub-type' is treated as 'nest' by every method.
        return self.attr.get('sub-type', 'nest')

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        sub_type = self._sub_type()
        if sub_type == 'nest':
            return self.nested_struct_type
        elif sub_type in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + sub_type
        else:
            raise Exception(f"Sub-type {sub_type} not supported yet")

    def _attr_typol(self):
        sub_type = self._sub_type()
        if sub_type in scalars:
            return f'.type = YNL_PT_U{c_upper(sub_type[1:])}, '
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remember the attr and count the elements; they are parsed later
        local_vars = ['const struct nlattr *attr2;']
        get_lines = [f'attr_{self.c_name} = attr;',
                     'ynl_attr_for_each_nested(attr2, attr)',
                     f'\t{var}->n_{self.c_name}++;']
        return get_lines, None, local_vars
701        return get_lines, None, local_vars
702
703
class TypeNestTypeValue(Type):
    """Nest parsed with extra type-value arguments.

    For every 'type-value' level the generated parser steps into the
    current attribute's payload (ynl_attr_data), records that attribute's
    type (ynl_attr_type), and passes all recorded types to the nested parse.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        local_vars = []
        get_lines = []
        last_attr = 'attr'
        tv_args = ''

        if 'type-value' in self.attr:
            tv_names = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(tv_names)};')
            local_vars.append(f'__u32 {", ".join(tv_names)};')
            for level in tv_names:
                get_lines.append(f'attr_{level} = ynl_attr_data({last_attr});')
                get_lines.append(f'{level} = ynl_attr_type(attr_{level});')
                last_attr = 'attr_' + level
            tv_args = ", " + ", ".join(tv_names)

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {last_attr}{tv_args});")
        return get_lines, init_lines, local_vars
731        return get_lines, init_lines, local_vars
732
733
class Struct:
    """C struct generated for an attribute set.

    With *type_list* the struct holds only the listed attrs; without it,
    it holds every attr of the set and is treated as a nest.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            base_name = family.ident_name
        else:
            base_name = family.ident_name + '-' + space_name
        self.render_name = c_lower(base_name)

        self.struct_name = 'struct ' + self.render_name
        # Avoid clashing with a family const of the same name
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in names]
        self.attrs = dict(self.attr_list)

        # Track the attr with the highest value (last one wins on ties)
        self.attr_max_val = None
        max_val = 0
        for _, attr in self.attr_list:
            if attr.value >= max_val:
                max_val, self.attr_max_val = attr.value, attr

    def __iter__(self):
        return iter(self.attrs)

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record inherited members (sorted, C-lowered); they must match the ctor's."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        ordered = sorted(self._inherited)
        self.inherited = list(map(c_lower, ordered))
788        self.inherited = [c_lower(x) for x in sorted(self._inherited)]
789
790
class EnumEntry(SpecEnumEntry):
    """Enum entry with its C constant name.

    value_change is True when the entry's value is not the implicit
    "previous + 1" (or 0 for the first entry), and always for flags.
    """

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        implicit_value = prev.value + 1 if prev else 0
        self.value_change = self.value != implicit_value or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
808        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
809
810
class EnumSet(SpecEnumSet):
    """Enum/flags definition with the C names used to render it."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        # An explicitly empty 'enum-name' means the enum has no C type name
        if 'enum-name' in yaml:
            if yaml['enum-name']:
                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
            else:
                self.enum_name = None
        else:
            self.enum_name = 'enum ' + self.render_name

        # Unnamed enums are carried as plain int
        if self.enum_name:
            self.user_type = self.enum_name
        else:
            self.user_type = 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        # Set the attributes above first; the base ctor creates the entries.
        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (lowest, highest) entry value.

        Raises:
            Exception: if the entry values are not contiguous.
        """
        values = [x.value for x in self.entries.values()]
        low, high = min(values), max(values)

        if high - low + 1 != len(self.entries):
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
845        return low, high
846
847
class AttrSet(SpecAttrSet):
    """Attribute set with the C enum names used by the generated code.

    name_prefix / max_name / cnt_name are the upper-case C names for the
    attribute enum prefix, its MAX entry and its count entry.  A subset
    reuses the names of the set it is a subset of.
    """

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                # The set named after the family uses the short "<fam>-a-" prefix
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            default_max = f"{self.name_prefix}max"
            default_cnt = f"__{self.name_prefix}max"
            self.max_name = c_upper(self.yaml.get('attr-max-name', default_max))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', default_cnt))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Compute c_name; C keywords get a trailing '_', and the set named
        like the family gets an empty c_name."""
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """Instantiate the Type subclass matching elem['type'].

        Attributes with a truthy 'multi-attr' are wrapped in TypeMultiAttr.
        """
        kind = elem['type']
        simple_types = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'binary': TypeBinary,
            'bitfield32': TypeBitfield32,
            'nest': TypeNest,
            'nest-type-value': TypeNestTypeValue,
        }

        if kind in scalars:
            type_cls = TypeScalar
        elif kind in simple_types:
            type_cls = simple_types[kind]
        elif kind == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            type_cls = TypeArrayNest
        else:
            raise Exception(f"No typed class for type {elem['type']}")

        attr = type_cls(self.family, self, elem, value)

        if 'multi-attr' in elem and elem['multi-attr']:
            attr = TypeMultiAttr(self.family, self, elem, value, attr)

        return attr
909
910
class Operation(SpecOperation):
    """Spec operation with the C names used by the generated code."""

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # True only when both 'do' and 'dump' define a request
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Resolve the base op, then build the C command enum name."""
        self.resolve_up(super())

        if self.is_async:
            prefix = self.family.async_op_prefix
        else:
            prefix = self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that another op names this one as its notification."""
        self.has_ntf = True
937
class Family(SpecFamily):
    """Netlink family spec extended with the data the C generator needs.

    Beyond the base SpecFamily, this computes the C names and prefixes,
    which attribute sets are used as message roots vs. pure nests (and in
    which direction), user hooks, and optionally a single global policy.
    """

    def __init__(self, file_name, exclude_ops):
        # Added by resolve:
        # These attributes are declared here for readers, then deleted right
        # away so they only exist once resolve() actually assigns them.
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        # NOTE(review): presumably loads the YAML spec and runs resolve();
        # confirm in lib's SpecFamily.
        super().__init__(file_name, exclude_ops=exclude_ops)

        # C macro names for the family name / version strings
        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        # Default to an empty 'definitions' list when the spec has none
        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # "linux/<name>.h" -> "<name>"; otherwise fall back to ident_name
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """Resolve the base spec, then compute the generator-specific state.

        The _load_* steps below run in order: root sets must be known before
        nested sets are classified, and both before attribute use is marked.
        """
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] -> {'set': dedup set, 'list': ordered names}
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct, for attr sets used only as nests
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def is_classic(self):
        """True for 'netlink-raw' (non-genetlink) families."""
        return self.proto == 'netlink-raw'

    def _mark_notify(self):
        # Flag every op that some other message names as its 'notify' target
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        # Note: this mutates the raw YAML operation entries in place
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        # Collect, per top-level attr set, which attributes appear in
        # requests and which in replies (do, dump and event) across all msgs.
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        # Try to reorder according to dependencies: a struct is moved to the
        # end of pure_nested_structs while it still references a
        # non-recursive nested set that has not been placed yet.
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                    if self.pure_nested_structs[nested].recursive:
                        continue
                    if nested not in pns_key_seen:
                        # Dicts keep insertion order; re-inserting moves struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_sets(self):
        # Breadth-first walk from the root sets, creating a Struct for every
        # attr set reachable only through 'nested-attributes'.
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                # Members the nested struct inherits from its parent attribute
                if 'type-value' in spec:
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'indexed-array':
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        # Seed request/reply use of nests directly referenced from root sets
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                    struct.child_nests.add(child_name)
                    child = self.pure_nested_structs.get(child_name)
                    if child:
                        if not child.recursive:
                            struct.child_nests.update(child.child_nests)
                        child.request |= struct.request
                        child.reply |= struct.reply
                # A set that (transitively) contains itself is recursive
                if attr_set in struct.child_nests:
                    struct.recursive = True

        self._sort_pure_types()

    def _load_attr_use(self):
        # Mark each attribute as used in requests and/or replies
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_global_policy(self):
        # Build one policy covering every request attribute of every op;
        # all ops must share a single attribute set.
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        # NOTE(review): if no op has an 'attribute-set', attr_set_name stays
        # None and the lookup below raises KeyError.
        self.global_policy = []
        self.global_policy_set = attr_set_name
        # Keep the attribute set's own order
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        # Collect unique pre/post hook names per mode, in first-seen order
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1199
1200
class RenderInfo:
    """Context for rendering the C code of one op/mode (or one attr set).

    Holds the family, output writer, op and mode, and the request/reply
    Struct objects derived from the op's attribute lists.  When `op` is
    falsy (rendering a bare attribute set), `attr_set` must be given and
    type_name is derived from it.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        self.fixed_hdr = None
        if op and op.fixed_header:
            self.fixed_hdr = 'struct ' + c_lower(op.fixed_header)

        # 'do' and 'dump' response parsing is identical
        self.type_consistent = True
        # Only 'dump' exists, so there is no 'do' variant to disambiguate
        self.type_oneside = False
        # Guard on op: like the checks above/below, a bare attr set render
        # passes no op, and `'dump' in None` would raise TypeError.
        if op and op_mode != 'do' and 'dump' in op:
            if 'do' not in op:
                self.type_oneside = True
            elif ('reply' in op['do']) != ('reply' in op["dump"]):
                self.type_consistent = False
            elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
                self.type_consistent = False

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            if attr_set in family.consts:
                self.type_name_conflict = True

        self.cw = cw

        self.struct = dict()
        # Notifications reuse the 'do' message layout
        if op_mode == 'notify':
            op_mode = 'do'
        for op_dir in ['request', 'reply']:
            if op:
                type_list = []
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])

    def type_empty(self, key):
        """True if the `key` struct has no attributes and no fixed header."""
        return len(self.struct[key].attr_list) == 0 and self.fixed_hdr is None
1254
1255
class CodeWriter:
    """Write C source with automatic indentation and brace handling.

    Output goes to stdout or, when out_file is given, to a temporary file
    that close_out_file() copies over out_file.  A closing brace from
    block_end() without a suffix is delayed until the next line, so that a
    following 'else' is joined onto it ("} else").
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # a blank line is pending before the next line
        self._block_end = False     # a '}' is pending, printed by the next p()
        self._silent_block = False  # previous line was a brace-less if/while/for
        self._ind = 0               # current indentation depth, in tabs
        self._ifdef_block = None    # CONFIG_* option of the open #ifdef, if any
        if out_file is None:
            self._out = sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """Finish output; for file output, publish the temp file to out_file.

        A '}' still pending from block_end() is written first so it is not
        silently dropped.  With overwrite=False an existing out_file whose
        contents are identical is left untouched.  In every case the temp
        file is closed and output reverts to stdout, so calling this again
        (e.g. from __del__) is a no-op.
        """
        if self._block_end:
            self._block_end = False
            self._out.write('\t' * self._ind + '}\n')
        if self._out is sys.stdout:
            return

        try:
            self._out.flush()
            # Avoid modifying the file if contents didn't change
            unchanged = (not self._overwrite and
                         os.path.isfile(self._out_file) and
                         filecmp.cmp(self._out.name, self._out_file, shallow=False))
            if not unchanged:
                self._out.seek(0)
                with open(self._out_file, 'w+') as out_file:
                    shutil.copyfileobj(self._out, out_file)
        finally:
            self._out.close()
            self._out = sys.stdout

    @classmethod
    def _is_cond(cls, line):
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        Lines ending in ':' (labels) are outdented one level, the line right
        after a brace-less if/while/for is indented one extra level,
        preprocessor lines ('#...') start at column 0, and add_ind adds extra
        levels.  An empty line is written without indentation.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write(('\t' * ind + line if line else '') + '\n')

    def nl(self):
        """Request a blank line before the next line (dropped by block_end)."""
        self._nl = True

    def block_start(self, line=''):
        """Write `line {` and indent one level."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close a block.

        With a suffix (e.g. ';') the '}' is written immediately; without
        one it is delayed until the next p() so an 'else' can follow it.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write `doc` as ' *'-prefixed comment lines wrapped near 79 columns."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype, wrapping long argument lists.

        If the one-line form is 80 columns or longer, a return type longer
        than 3 characters goes on its own line, and arguments wrap at about
        76 columns, aligned under the opening parenthesis.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        # Continuation indent: tabs then spaces up to the '(' column
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        for arg in args[1:]:
            next_len = len(v) + len(arg)
            if v[0] == '\t':
                # Count tabs at their visual width of 8 columns
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += arg
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        The caller's list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.write_func_lvar(local_vars=local_vars)

        self.block_start()
        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write (name, value) pairs as #defines with tab-aligned values."""
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write (field, value) pairs as tab-aligned designated initializers."""
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch the surrounding #ifdef CONFIG_<config> (None closes it)."""
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1451
1452
1453scalars = {'u8', 'u16', 'u32', 'u64', 's8', 's16', 's32', 's64', 'uint', 'sint'}
1454
1455direction_to_suffix = {
1456    'reply': '_rsp',
1457    'request': '_req',
1458    '': ''
1459}
1460
1461op_mode_to_wrapper = {
1462    'do': '',
1463    'dump': '_list',
1464    'notify': '_ntf',
1465    'event': '',
1466}
1467
1468_C_KW = {
1469    'auto',
1470    'bool',
1471    'break',
1472    'case',
1473    'char',
1474    'const',
1475    'continue',
1476    'default',
1477    'do',
1478    'double',
1479    'else',
1480    'enum',
1481    'extern',
1482    'float',
1483    'for',
1484    'goto',
1485    'if',
1486    'inline',
1487    'int',
1488    'long',
1489    'register',
1490    'return',
1491    'short',
1492    'signed',
1493    'sizeof',
1494    'static',
1495    'struct',
1496    'switch',
1497    'typedef',
1498    'union',
1499    'unsigned',
1500    'void',
1501    'volatile',
1502    'while'
1503}
1504
1505
def rdir(direction):
    """Return the opposite message direction.

    'reply' and 'request' are swapped; anything else is returned unchanged.
    """
    return ('request' if direction == 'reply'
            else 'reply' if direction == 'request'
            else direction)
1512
1513
def op_prefix(ri, direction, deref=False):
    """Build the C identifier stem "<family>_<type>[suffix]" for @ri.

    In 'do' (or no) mode the suffix is the plain direction suffix.  For
    other modes, requests get '_req' (plus '_dump' unless only a dump
    exists).  Replies use the direction suffix (deref) or the mode wrapper
    suffix when do/dump replies are identical, otherwise '_rsp_dump' /
    '_rsp_list'.
    """
    if not ri.op_mode or ri.op_mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req' if ri.type_oneside else '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[ri.op_mode]

    return f"{ri.family.c_name}_{ri.type_name}{tail}"
1535
1536
def type_name(ri, direction, deref=False):
    """Return the C struct type ("struct <prefix>") for @direction of @ri."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1539
1540
def print_prototype(ri, direction, terminate=True, doc=None):
    """Print the prototype of the user-space helper for @ri's op and mode.

    The helper takes the socket and, if the mode has a request, a pointer to
    the request struct.  It returns a pointer to the reply struct if the mode
    has a reply, otherwise int.  @terminate appends ';'.
    """
    mode_spec = ri.op[ri.op_mode]
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        req_var = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{req_var}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1557
1558
def print_req_prototype(ri):
    """Print the 'do' request helper prototype, preceded by the op's doc.

    An op without a 'doc' key gets no comment instead of raising KeyError.
    """
    doc = ri.op['doc'] if 'doc' in ri.op else None
    print_prototype(ri, "request", doc=doc)
1561
1562
def print_dump_prototype(ri):
    """Print the dump helper prototype (request direction, no doc comment)."""
    print_prototype(ri, direction="request")
1565
1566
def put_typol_fwd(cw, struct):
    """Emit the extern declaration of @struct's policy nest."""
    nest_sym = struct.render_name + '_nest'
    cw.p('extern const struct ynl_policy_nest ' + nest_sym + ';')
1569
1570
def put_typol(cw, struct):
    """Emit @struct's attribute policy table followed by its policy nest."""
    max_attr = struct.attr_set.max_name
    policy_sym = f'{struct.render_name}_policy'
    nest_sym = f'{struct.render_name}_nest'

    cw.block_start(line=f'const struct ynl_policy_attr {policy_sym}[{max_attr} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {nest_sym} =')
    cw.p(f'.max_attr = {max_attr},')
    cw.p(f'.table = {policy_sym},')
    cw.block_end(line=';')
    cw.nl()
1586
1587
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit `<render_name>_str()`, which looks a value up in @map_name.

    The argument is typed with the enum's user_type (int without an enum).
    For flag enums the value is first converted with ffs() - 1.  Values
    outside the map return NULL.
    """
    arg_type = enum.user_type if enum else 'int'
    cw.write_func_prot('const char *', f'{render_name}_str', [f'{arg_type} {arg_name}'])
    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    out_of_range = f'{arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name})'
    cw.p(f'if ({out_of_range})')
    cw.p('return NULL;')
    cw.p(f'return {map_name}[{arg_name}];')
    cw.block_end()
    cw.nl()
1601
1602
def put_op_name_fwd(family, cw):
    """Emit the prototype of `<family>_op_str()`."""
    func_name = family.c_name + '_op_str'
    cw.write_func_prot('const char *', func_name, ['int op'], suffix=';')
1605
1606
def put_op_name(family, cw):
    """Emit the op-value -> name string map and its `<family>_op_str()`.

    Only messages with a reply value are listed.  The entry is indexed by the
    op's enum name when request and reply values match, otherwise by the raw
    reply value.
    """
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
1626
1627
def put_enum_to_str_fwd(family, cw, enum):
    """Emit the prototype of `<enum>_str()`."""
    cw.write_func_prot('const char *', enum.render_name + '_str',
                       [enum.user_type + ' value'], suffix=';')
1631
1632
def put_enum_to_str(family, cw, enum):
    """Emit @enum's value -> name string map and its `<enum>_str()` helper."""
    map_name = enum.render_name + '_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    entry_lines = (f'[{e.value}] = "{e.name}",' for e in enum.entries.values())
    for entry_line in entry_lines:
        cw.p(entry_line)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1642
1643
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of `<struct>_put(nlh, attr_type, obj)`."""
    params = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]
    ri.cw.write_func_prot('int', f'{struct.render_name}_put', params, suffix=suffix)
1651
1652
def put_req_nested(ri, struct):
    """Emit `<struct>_put()`, which writes @struct as a nested attribute.

    The generated body opens a nest of type attr_type, lets each member emit
    its own put code for `obj`, closes the nest and returns 0.
    """
    cw = ri.cw
    put_req_nested_prototype(ri, struct, suffix='')
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = ynl_attr_nest_start(nlh, attr_type);")

    for _, member in struct.member_list():
        member.attr_put(ri, "obj")

    cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1669
1670
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a C attribute parser for @struct.

    Shared by parse_rsp_nested() and parse_rsp_msg(); the caller has
    already written the function prototype. This writes the opening
    brace, the local variable declarations, @init_lines, a walk over the
    attributes that delegates to each member's attr_get(), then a second
    phase that allocates and fills the C arrays backing indexed-array and
    multi-attr members, and finally the return statement.

    @ri:         render info (code writer ri.cw, ri.family, ri.fixed_hdr)
    @struct:     the struct whose members are parsed
    @init_lines: C statements written right after the declarations;
                 appended to by this function
    @local_vars: C local variable declarations; appended to (in place)
                 by this function
    """
    if struct.nested:
        # Nested parsers iterate over the 'nested' argument.
        iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        # Message parsers iterate over the message, starting after
        # yarg->ys->family->hdr_len bytes of header.
        if ri.fixed_hdr:
            local_vars += ['void *hdr;']
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"

    # Find the members stored as C arrays: indexed arrays of nests or
    # scalars (array_nests) and multi-attr members (multi_attrs).
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] == 'nest':
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        # Members with nested attributes are parsed by calling their own
        # <nest>_parse(), which takes a ynl_parse_arg.
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # One element counter per array member (sorted for stable output);
    # presumably incremented by the code attr_get() emits — TODO confirm.
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    # Inherited values arrive as extra __u32 parser arguments
    # (see parse_rsp_nested_prototype()); store them in the destination.
    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    # Copy the fixed header; for non-classic families it follows the
    # struct genlmsghdr.
    if ri.fixed_hdr:
        if ri.family.is_classic():
            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
        else:
            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
    # Array members get calloc()ed below; refuse a destination that already
    # has them set rather than overwrite the pointer.
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    # First pass: walk all attributes; every member emits its own handling.
    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    # 'first' stays True until a member reports that it emitted code;
    # presumably attr_get() uses it to choose "if" vs "else if" — confirm.
    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for indexed arrays: allocate n_<name> elements and walk
    # the nest held in attr_<name> (presumably recorded by the first pass)
    # to fill them.
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attr members: allocate n_<name> elements and
    # re-walk all attributes, picking those whose type matches the member.
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return 0; message parsers return YNL_PARSE_CB_OK.
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1794
1795
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Write the prototype of <struct>_parse() for a nested attribute set.

    Besides the parse argument and the nest itself, every inherited member
    of @struct is passed in as an extra __u32 argument.
    """
    inherited_args = ['__u32 ' + name for name in struct.inherited]
    params = ['struct ynl_parse_arg *yarg',
              'const struct nlattr *nested',
              *inherited_args]
    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', params,
                          suffix=suffix)
1804
1805
def parse_rsp_nested(ri, struct):
    """Write the definition of the parser for nested attribute set @struct.

    An empty nest gets a trivial body returning 0; otherwise the body is
    produced by the shared _multi_parse() emitter.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    declarations = ['const struct nlattr *attr;',
                    f'{struct.ptr_name}dst = yarg->data;']

    if not struct.member_list():
        # Empty nest
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, struct, [], declarations)
1821
1822
def parse_rsp_msg(ri, deref=False):
    """Write the parse callback for the reply (or event) message of ri.op.

    Nothing is written when the current op mode has no reply, unless the
    mode is 'event'. Replies without members get a body that just returns
    YNL_PARSE_CB_OK.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    dst_type = type_name(ri, "reply", deref=deref)
    parser_name = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    ri.cw.write_func_prot('int', parser_name,
                          ['const struct nlmsghdr *nlh',
                           'struct ynl_parse_arg *yarg'])

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply, ['dst = yarg->data;'],
                 [f'{dst_type} *dst;', 'const struct nlattr *attr;'])
1844
1845
def print_req(ri):
    """Write the user-space "do" request function for ri.op.

    The generated C builds the request message (optional fixed header,
    then the request attributes) and runs it with ynl_exec(). If the op
    has a reply, a response struct is calloc()ed and filled by the reply
    parser; it is returned on success, or freed with NULL returned on
    error. Without a reply the function returns 0 on success and -1 on
    error.
    """
    direction = "request"
    has_reply = 'reply' in ri.op[ri.op_mode]

    lvars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
             'struct nlmsghdr *nlh;',
             'int err;']
    if has_reply:
        lvars.append(f'{type_name(ri, rdir(direction))} *rsp;')
    if ri.fixed_hdr:
        lvars.extend(('size_t hdr_len;', 'void *hdr;'))

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(lvars)

    if ri.family.is_classic():
        start = f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name});"
    else:
        start = f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);"
    ri.cw.p(start)

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()

    if ri.fixed_hdr:
        for stmt in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            ri.cw.p(stmt)
        ri.cw.nl()

    for _, member in ri.struct["request"].member_list():
        member.attr_put(ri, "req")
    ri.cw.nl()

    if has_reply:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        ri.cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        ri.cw.nl()

    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto err_free;' if has_reply else 'return -1;')
    ri.cw.nl()

    ri.cw.p(f"return {'rsp' if has_reply else '0'};")
    ri.cw.nl()

    if has_reply:
        # Error path: release the partially filled response.
        ri.cw.p('err_free:')
        ri.cw.p(call_free(ri, rdir(direction), 'rsp'))
        ri.cw.p("return NULL;")

    ri.cw.block_end()
1913
1914
def print_dump(ri):
    """Write the user-space dump function for ri.op.

    The generated C sets up a ynl_dump_state, builds the dump request
    (optional fixed header, then request attributes if the op has any)
    and runs it with ynl_exec_dump(). It returns yds.first on success;
    on error it frees yds.first with the reply free helper and returns
    NULL.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()

    lvars = ['struct ynl_dump_state yds = {};',
             'struct nlmsghdr *nlh;',
             'int err;']
    if ri.fixed_hdr:
        lvars.extend(('size_t hdr_len;', 'void *hdr;'))
    ri.cw.write_func_lvar(lvars)

    ri.cw.p('yds.yarg.ys = ys;')
    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.p("yds.yarg.data = NULL;")
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    ri.cw.p(f'yds.rsp_cmd = {rsp_cmd};')
    ri.cw.nl()

    if ri.family.is_classic():
        start = f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});"
    else:
        start = f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);"
    ri.cw.p(start)

    if ri.fixed_hdr:
        for stmt in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            ri.cw.p(stmt)
        ri.cw.nl()

    # Dumps may or may not take request attributes.
    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.nl()
        for _, member in ri.struct["request"].member_list():
            member.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
1968
1969
def call_free(ri, direction, var):
    """Return the C statement freeing @var with the op's _free() helper."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
1972
1973
def free_arg_name(direction):
    """Return the parameter name used by a generated _free() function.

    A falsy @direction (nested types) yields 'obj'; otherwise the result is
    direction_to_suffix[direction] minus its first character.
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
1978
1979
def print_alloc_wrapper(ri, direction):
    """Write a static inline <op>_alloc() returning a calloc()ed struct."""
    name = op_prefix(ri, direction)
    struct_type = f'struct {name}'
    ri.cw.write_func_prot(f'static inline {struct_type} *', f"{name}_alloc", ["void"])
    ri.cw.block_start()
    ri.cw.p(f'return calloc(1, sizeof({struct_type}));')
    ri.cw.block_end()
1986
1987
def print_free_prototype(ri, direction, suffix=';'):
    """Write the prototype of <op>_free() for the @direction type.

    When ri.type_name_conflict is set, the struct tag carries a trailing
    underscore while the function name does not.
    """
    name = op_prefix(ri, direction)
    struct_name = f"{name}_" if ri.type_name_conflict else name
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{name}_free", [param], suffix=suffix)
1995
1996
def _print_type(ri, direction, struct):
    """Write the C struct definition for @struct.

    The tag is <family>_<type_name><direction suffix>, plus '_' for
    direction-less types whose name conflicts, plus '_dump' for dump
    variants of two-sided types. Presence fields (lengths and bits) are
    grouped in an anonymous sub-struct named '_present'.
    """
    conflict_mark = '_' if not direction and ri.type_name_conflict else ''
    dump_mark = '_dump' if ri.op_mode == 'dump' and not ri.type_oneside else ''
    tag = (f"{ri.family.c_name}_{ri.type_name}"
           f"{direction_to_suffix[direction]}{conflict_mark}{dump_mark}")

    ri.cw.block_start(line=f"struct {tag}")

    if ri.fixed_hdr:
        ri.cw.p(ri.fixed_hdr + ' _hdr;')
        ri.cw.nl()

    presence = []
    for _, member in struct.member_list():
        for kind in ('len', 'bit'):
            field = member.presence_member(ri.ku_space, kind)
            if field:
                presence.append(field)
    if presence:
        ri.cw.block_start(line="struct")
        for field in presence:
            ri.cw.p(field)
        ri.cw.block_end(line='_present;')
        ri.cw.nl()

    for name in struct.inherited:
        ri.cw.p(f"__u32 {name};")

    for _, member in struct.member_list():
        member.struct_member(ri)

    ri.cw.block_end(line=';')
    ri.cw.nl()
2032
2033
def print_type(ri, direction):
    """Write the struct definition for the @direction side of ri's op."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
2036
2037
def print_type_full(ri, struct):
    """Write the struct definition of a direction-less (nested) type."""
    _print_type(ri, direction="", struct=struct)
2040
2041
def print_type_helpers(ri, direction, deref=False):
    """Write the _free() prototype and, for user-space requests, setters."""
    print_free_prototype(ri, direction)
    ri.cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
2050
2051
def print_req_type_helpers(ri):
    """Write alloc/free/setter helpers for a non-empty request type."""
    if not ri.type_empty("request"):
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
2057
2058
def print_rsp_type_helpers(ri):
    """Write the reply type helpers if the current op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2063
2064
def print_parse_prototype(ri, direction, terminate=True):
    """Write the prototype of <op>_rsp_parse() or <op>_req_parse().

    @terminate selects a declaration (';') over a definition header.
    """
    kind = "_rsp" if direction == "reply" else "_req"
    type_base = f"{ri.op.render_name}{kind}"
    params = ['const struct nlattr **tb', f"struct {type_base} *req"]
    ri.cw.write_func_prot('void', f"{type_base}_parse", params,
                          suffix=';' if terminate else '')
2073
2074
def print_req_type(ri):
    """Write the request struct, unless the request type is empty."""
    if not ri.type_empty("request"):
        print_type(ri, "request")
2079
2080
def print_req_free(ri):
    """Write the request _free() if the current op mode takes a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2085
2086
def print_rsp_type(ri):
    """Write the reply struct for events and for do/dump ops with a reply."""
    mode = ri.op_mode
    if mode == 'event' or (mode in ('do', 'dump') and 'reply' in ri.op[mode]):
        print_type(ri, 'reply')
2095
2096
def print_wrapped_type(ri):
    """Write the wrapper struct embedding the reply type as 'obj'.

    Dump wrappers get a 'next' pointer. Notify/event wrappers get family,
    cmd, a 'next' pointer of type struct ynl_ntf_base_type and a 'free'
    callback. The embedded object is 8-byte aligned.
    """
    wrapper = type_name(ri, 'reply')
    ri.cw.block_start(line=f"{wrapper}")

    mode = ri.op_mode
    if mode == 'dump':
        ri.cw.p(f"{wrapper} *next;")
    elif mode in ('notify', 'event'):
        for field in ('__u16 family;',
                      '__u8 cmd;',
                      'struct ynl_ntf_base_type *next;'):
            ri.cw.p(field)
        ri.cw.p(f"void (*free)({wrapper} *ntf);")

    inner = type_name(ri, 'reply', deref=True)
    ri.cw.p(f"{inner} obj __attribute__((aligned(8)));")
    ri.cw.block_end(line=';')
    ri.cw.nl()

    print_free_prototype(ri, 'reply')
    ri.cw.nl()
2111
2112
2113def _free_type_members_iter(ri, struct):
2114    for _, attr in struct.member_list():
2115        if attr.free_needs_iter():
2116            ri.cw.p('unsigned int i;')
2117            ri.cw.nl()
2118            break
2119
2120
2121def _free_type_members(ri, var, struct, ref=''):
2122    for _, attr in struct.member_list():
2123        attr.free(ri, var, ref)
2124
2125
def _free_type(ri, direction, struct):
    """Write the definition of the _free() function for @struct.

    Members are released first; the object itself is free()d only for a
    non-empty @direction (direction-less types are not free()d here).
    """
    arg = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        ri.cw.p(f'free({arg});')
    ri.cw.block_end()
    ri.cw.nl()
2137
2138
def free_rsp_nested_prototype(ri):
    """Write the _free() prototype for a direction-less (nested) type."""
    print_free_prototype(ri, direction="")
2141
2142
def free_rsp_nested(ri, struct):
    """Write the _free() definition for direction-less (nested) @struct."""
    _free_type(ri, direction="", struct=struct)
2145
2146
def print_rsp_free(ri):
    """Write the reply _free() if the current op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2151
2152
def print_dump_type_free(ri):
    """Write the _free() for a dump reply list.

    The generated loop follows 'next' pointers until YNL_LIST_END,
    releasing each node's members and then the node itself.
    """
    node_type = type_name(ri, 'reply')
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    ri.cw.p(f"{node_type} *next = rsp;")
    ri.cw.nl()
    ri.cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply)
    for stmt in ('rsp = next;', 'next = rsp->next;'):
        ri.cw.p(stmt)
    ri.cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.block_end()
    ri.cw.nl()
2171
2172
def print_ntf_type_free(ri):
    """Write the _free() for a notification: its members, then the object."""
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.nl()
2181
2182
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Write a policy declaration or the opening line of its definition.

    With @terminate an extern declaration is written, except for per-op
    policies (@ri given) that are static, which need none. Without it the
    definition header ending in ' = {' is written, 'static' when
    applicable. Per-op policies are named after the op (plus the op mode
    for dual-policy ops); others after the struct.
    """
    is_static = bool(ri) and policy_should_be_static(struct.family)

    if terminate:
        if is_static:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix, suffix = ('static ' if is_static else ''), ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2205
2206
def print_req_policy(cw, struct, ri=None):
    """Write the complete nla_policy array for @struct.

    Per-op policies are wrapped in the op's 'config-cond' #ifdef, if any.
    """
    if ri and ri.op:
        cw.ifdef_block(ri.op.get('config-cond', None))
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for member in (m for _, m in struct.member_list()):
        member.attr_policy(cw)
    cw.p("};")
    cw.ifdef_block(None)
    cw.nl()
2216
2217
def kernel_can_gen_family_struct(family):
    """Whether the genl_family struct itself is generated for @family.

    Only families whose protocol is plain 'genetlink' qualify.
    """
    proto = family.proto
    return proto == 'genetlink'
2220
2221
def policy_should_be_static(family):
    """Whether generated policies can be file-local (static).

    True for split kernel policy, or when the family struct is generated.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2224
2225
def print_kernel_policy_ranges(family, cw):
    """Write netlink_range_validation structs for 'full-range' checks.

    Only request attributes of non-subset attribute sets are considered.
    Unsigned types use the plain struct with ULL limits; signed types use
    the _signed variant with LL limits.
    """
    header_written = False
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if not header_written:
                cw.p('/* Integer value ranges */')
                header_written = True

            unsigned = attr.type[0] == 'u'
            sign, suffix = ('', 'ULL') if unsigned else ('_signed', 'LL')
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            limits = [(bound, attr.get_limit_str(bound, suffix=suffix))
                      for bound in ('min', 'max') if bound in attr.checks]
            cw.write_struct_init(limits)
            cw.block_end(line=';')
            cw.nl()
2253
2254
def print_kernel_op_table_fwd(family, cw, terminate):
    """Write the ops table declaration/opening and handler prototypes.

    Without @terminate this only opens the ops table definition
    (``<qual> struct <ops type> <family>_nl_ops[<cnt>] =``). With
    @terminate it writes the header side: an extern declaration of the
    table when it is exported (i.e. the family struct is not generated
    here), followed by prototypes for the pre/post do/dump hooks and for
    every doit/dumpit handler of the synchronous ops.

    The array size of an exported table must equal the number of entries
    print_kernel_op_table() emits, which skips async ops.
    """
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.ident_name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            # One entry per do and per dump of every synchronous op.
            cnt = 0
            for op in family.ops.values():
                if op.is_async:
                    continue
                if 'do' in op:
                    cnt += 1
                if 'dump' in op:
                    cnt += 1
        else:
            # One entry per synchronous op; async ops get no table entry,
            # so counting them would leave zero-filled trailing entries.
            cnt = sum(1 for op in family.ops.values() if not op.is_async)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    cw.nl()
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.ident_name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.ident_name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
2320
2321
def print_kernel_op_table_hdr(family, cw):
    """Header side of the ops table: extern (if exported) and prototypes."""
    print_kernel_op_table_fwd(family, cw, terminate=True)
2324
2325
def print_kernel_op_table(family, cw):
    """Write the kernel ops table definition for @family.

    For 'global' and 'per-op' kernel policy every op with op.is_async unset
    gets one entry holding its doit/dumpit handlers, optional validate and
    flags members, and (for 'per-op') its policy and maxattr. For 'split'
    policy the do and the dump of each op get separate entries, each with
    its own pre/post callbacks, validate flags, policy and a
    GENL_CMD_CAP_DO / GENL_CMD_CAP_DUMP flag. Every entry is wrapped in the
    op's 'config-cond' #ifdef, if it has one.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): the policy is taken from the 'do' request;
                # an op without 'do' would raise KeyError here.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Spec hook names -> genl_split_ops member names, per mode.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Each split entry covers one mode, so drop the flags
                    # that only apply to the other mode.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have a separate policy per mode.
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Always advertise which capability (do/dump) this entry has.
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2404
2405
def print_kernel_mcgrp_hdr(family, cw):
    """Write the enum of multicast group IDs (nothing if there are none)."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{grp['name']},"))
    cw.block_end(';')
    cw.nl()
2416
2417
def print_kernel_mcgrp_src(family, cw):
    """Write the genl_multicast_group array indexed by group ID, if any."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        grp_id = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{grp_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2429
2430
def print_kernel_family_struct_hdr(family, cw):
    """Declare the generated genl_family and, if used, sock-priv hooks."""
    if not kernel_can_gen_family_struct(family):
        return

    cname = family.c_name
    cw.p(f"extern struct genl_family {cname}_nl_family;")
    cw.nl()

    if 'sock-priv' in family.kernel_family:
        priv_type = family.kernel_family["sock-priv"]
        for hook in ('init', 'destroy'):
            cw.p(f'void {cname}_nl_sock_priv_{hook}({priv_type} *priv);')
        cw.nl()
2441
2442
def print_kernel_family_struct_src(family, cw):
    """Define the genl_family for @family, when it can be generated.

    With per-socket private data ('sock-priv'), static trampolines taking
    a void pointer are written that forward to
    <family>_nl_sock_priv_init()/_destroy(); the family struct points at
    the trampolines.
    """
    if not kernel_can_gen_family_struct(family):
        return

    if 'sock-priv' in family.kernel_family:
        # Generate "trampolines" to make CFI happy
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
                      ["void *priv"])
        cw.nl()
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
                      ["void *priv"])
        cw.nl()

    # Use c_name, exactly like the extern declaration written by
    # print_kernel_family_struct_hdr(), so the definition matches the
    # declaration and is a valid C identifier even if ident_name has dashes.
    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    if 'sock-priv' in family.kernel_family:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
2478
2479
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a uAPI enum, naming it from @obj when possible.

    An obj[@enum_name] entry takes precedence; if present but empty, the
    enum stays anonymous. Otherwise obj[@ckey], prefixed with the family's
    C name, is used when available.
    """
    if enum_name in obj:
        tag = c_lower(obj[enum_name]) if obj[enum_name] else None
    elif ckey and ckey in obj:
        tag = family.c_name + '_' + c_lower(obj[ckey])
    else:
        tag = None
    cw.block_start(line='enum' if tag is None else 'enum ' + tag)
2488
2489
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """Write the single command enum of a unified message-ID family.

    Explicit values are written only where an op breaks the running
    sequence. With @separate_ntf, notify/event ops are left out. The enum
    ends with the count marker and the max value, the latter either as an
    enumerator or (@max_by_define) as a #define after the enum.
    """
    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    expected = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        if op.value == expected:
            cw.p(op.enum_name + ',')
        else:
            cw.p(op.enum_name + f" = {op.value},")
            expected = op.value
        expected += 1
    cw.nl()

    if max_by_define:
        cw.p(cnt_name)
        cw.block_end(line=';')
        cw.p(f"#define {max_name} {max_value}")
    else:
        cw.p(cnt_name + ',')
        cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
    cw.nl()
2515
2516
def render_uapi_directional(family, cw, max_by_define):
    """Write the USER and KERNEL message enums of a directional-ID family.

    The USER enum lists every op with 'do' and no 'event'. The KERNEL
    enum lists ops with a 'do' reply (suffixed _REPLY), notifications and
    events. Each enum starts with a *_NONE = 0 entry and ends with a
    __*_CNT marker and *_MAX, the latter as an enumerator or, with
    @max_by_define, as a #define. An explicit value is written only when
    an op's value is set and differs from the running counter.
    """
    def emit_enum(side, entries):
        # side is 'USER' or 'KERNEL'; entries yields (enum name, value).
        max_name = f"{family.op_prefix}{side}_MAX"
        cnt_name = f"__{family.op_prefix}{side}_CNT"
        max_value = f"({cnt_name} - 1)"

        cw.block_start(line='enum')
        cw.p(c_upper(f'{family.name}_MSG_{side}_NONE = 0,'))
        val = 0
        for name, value in entries:
            if value and value != val:
                cw.p(f"{name} = {value},")
                val = value
            else:
                cw.p(name + ',')
            val += 1
        cw.nl()

        if max_by_define:
            cw.p(cnt_name)
            cw.block_end(line=';')
            cw.p(f"#define {max_name} {max_value}")
        else:
            cw.p(cnt_name + ',')
            cw.p(f"{max_name} = {max_value}")
            cw.block_end(line=';')
        cw.nl()

    def kernel_entries():
        for op in family.msgs.values():
            is_ntf = 'event' in op or 'notify' in op
            if is_ntf or ('do' in op and 'reply' in op['do']):
                name = op.enum_name if is_ntf else f'{op.enum_name}_REPLY'
                yield name, op.value

    emit_enum('USER', ((op.enum_name, op.value)
                       for op in family.msgs.values()
                       if 'do' in op and 'event' not in op))
    emit_enum('KERNEL', kernel_entries())
2569
2570
def render_uapi(family, cw):
    """
    Render the UAPI header for @family through the code writer @cw.

    Output order: include guard, family name/version defines, the spec's
    'definitions' (consts as #defines, enums and flags as C enums with
    optional kernel-doc), one enum per attribute set that is not a subset
    of another set, the message ID enum(s), an optional separate
    notification enum, multicast group name defines, and finally the
    closing include guard.
    """
    # The header name may contain '/', which cannot appear in a macro name
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive 'const' definitions are collected in @defines and written
    # as one group; any other definition type flushes the pending group.
    defines = []
    for const in family['definitions']:
        # Definitions that name an existing header are skipped here
        if const.get('header'):
            continue

        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.header:
                continue

            # Entry-level docs get a kernel-doc ('/**') block naming the
            # enum; if only the enum itself is documented, a plain '/*'
            # comment is used instead.
            if enum.has_doc():
                if enum.has_entry_doc():
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            # Prefix for the private MASK/MAX/CNT names below
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                # Explicit initializer only for entries flagged value_change
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            # Optional private trailer: a MASK of all bits for flags, or a
            # CNT sentinel plus MAX = CNT - 1 for enums.
            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    cnt_name = enum.enum_cnt_name
                    max_name = c_upper(name_pfx + 'max')
                    # Without an explicit counter name fall back to '__<PREFIX>MAX'
                    if not cnt_name:
                        cnt_name = '__' + name_pfx + 'max'
                    cw.p(c_upper(cnt_name) + ',')
                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # NOTE(review): 'c-define-name' is looked up on the family, not on
            # @const (the mcast-group code below uses the group's own key) —
            # confirm this is intended.
            defines.append([c_upper(family.get('c-define-name',
                                               f"{family.ident_name}-{const['name']}")),
                            const['value']])

    # Flush consts that were not followed by another definition type
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    # With 'max-by-define' the MAX value is a #define after the enum
    # instead of an enum member.
    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        # Subset sets get no enum of their own here
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        # @val tracks the next implicit enumerator value; '= N' is written
        # only when attr.value differs from it (assumes uapi_enum_start()
        # emits no enum members — TODO confirm).
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        if attr_set.items():
            cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    # An 'async-prefix' on operations moves notifications into their own enum
    separate_ntf = 'async-prefix' in family['operations']

    if family.msg_id_model == 'unified':
        render_uapi_unified(family, cw, max_by_define, separate_ntf)
    elif family.msg_id_model == 'directional':
        render_uapi_directional(family, cw, max_by_define)
    else:
        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            # Only notify/event messages belong here (separate_ntf is
            # always true inside this branch)
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    # One define per group, named by 'c-define-name' or
    # '<family>-mcgrp-<name>', whose value is the group name.
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2711
2712
def _render_user_ntf_entry(ri, op):
    """
    Emit one designated initializer of a ynl_ntf_info[] table.

    The entry is indexed by @op's enum name.  Its fields are:
      - the size of the event type;
      - the reply '_parse' callback;
      - the reply '_nest' policy;
      - the notify '_free' helper.
    """
    cw = ri.cw
    cw.block_start(line=f"[{op.enum_name}] = ")
    alloc_type = type_name(ri, 'event')
    cw.p(f".alloc_sz\t= sizeof({alloc_type}),")
    parse_fn = op_prefix(ri, 'reply', deref=True)
    cw.p(f".cb\t\t= {parse_fn}_parse,")
    reply_nest = ri.struct['reply'].render_name
    cw.p(f".policy\t\t= &{reply_nest}_nest,")
    free_fn = op_prefix(ri, 'notify')
    cw.p(f".free\t\t= (void *){free_fn}_free,")
    cw.block_end(line=',')
2720
2721
def render_user_family(family, cw, prototype):
    """
    Render the user-space 'struct ynl_family' descriptor for @family.

    With @prototype set only an extern declaration is emitted.  Otherwise
    the full definition is emitted.  When the family has notifications,
    their ynl_ntf_info[] table is emitted first and referenced from the
    descriptor.

    Raises Exception for a notification that has neither 'notify' nor
    'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                # Build render info from the op this notification refers to
                op = family.ops[ntf_op['notify']]
                ri = RenderInfo(cw, family, "user", op, "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        # NOTE(review): events are already handled by the loop above; if an
        # event could also appear in family.ops it would get a second
        # initializer here — confirm this loop is reachable.
        for op_name, op in family.ops.items():
            if 'event' not in op:
                continue
            ri = RenderInfo(cw, family, "user", op, "event")
            _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.is_classic():
        # Classic families carry a protocol number, and their header is
        # only the family's fixed header (no genlmsghdr).
        cw.p('.is_classic\t= true,')
        cw.p(f'.classic_id\t= {family.get("protonum")},')
        cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
    elif family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
    cw.block_end(line=';')
2762
2763
def family_contains_bitfield32(family):
    """
    Tell whether any attribute of @family has type 'bitfield32'.

    Attribute sets that are a subset of another set are not inspected.
    Returns True on the first match, False otherwise.
    """
    return any(attr.type == "bitfield32"
               for _, attr_set in family.attr_sets.items()
               if not attr_set.subset_of
               for _, attr in attr_set.items())
2772
2773
def find_kernel_root(full_path):
    """
    Locate the kernel source tree that contains @full_path.

    Walks up from @full_path until a directory holding a MAINTAINERS
    file is found.

    Returns a (root, sub_path) tuple: the directory holding MAINTAINERS
    and the path of @full_path relative to it.

    Raises FileNotFoundError if no ancestor directory holds MAINTAINERS.
    """
    start_path = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # sub_path carries a trailing separator from the first join()
            return parent, sub_path[:-1]
        # dirname() of the filesystem root, or of '' once a relative path
        # is used up, returns its input unchanged: stop instead of looping
        # forever.
        if parent == full_path:
            raise FileNotFoundError(
                f'No MAINTAINERS file found above {start_path}; '
                'is the spec inside a kernel source tree?')
        full_path = parent
2782
2783
def main():
    """
    Command line entry point: parse a YNL spec and emit C code.

    --mode selects the output:
      - 'uapi': the UAPI header;
      - 'kernel': in-kernel policies, op tables and family struct;
      - 'user': user-space YNL library code.
    --header / --source selects whether a header or a source file is
    generated.

    Exits with status 1 when the spec cannot be parsed as YAML, or when
    it carries a license other than the one required for generated code.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True,
                        choices=('user', 'kernel', 'uapi'))
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    # --header and --source share one destination whose default is None,
    # so None means neither flag was given.
    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    # Diagnostics go to stderr: without -o the generated code itself is
    # written to stdout.
    required_license = '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)'
    try:
        parsed = Family(args.spec, exclude_ops)
    except yaml.YAMLError as exc:
        print(exc, file=sys.stderr)
        sys.exit(1)
    if parsed.license != required_license:
        print('Spec license:', parsed.license, file=sys.stderr)
        print('License must be:', required_license, file=sys.stderr)
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    # File preamble: SPDX tag, provenance and the generator invocation
    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        # Record the extra generator arguments in the output
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    # Header paired with the output file: its last two characters
    # (e.g. '.c') are replaced with '.h'.
    if args.out_file:
        hdr_file = os.path.basename(args.out_file[:-2]) + ".h"
    else:
        hdr_file = "generated_header_file.h"

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = []
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    if args.mode == 'user':
        headers.append(parsed.uapi_header)
    # Include each header once, keeping first-seen order
    seen_header = []
    for one in headers:
        if one not in seen_header:
            cw.p(f"#include <{one}>")
            seen_header.append(one)
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            # Print the heading only if some nested struct is used in requests
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for op_name, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            print_kernel_policy_ranges(parsed, cw)

            # Print the heading only if some nested struct is used in requests
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            for op_name, op in parsed.ops.items():
                if parsed.kernel_policy in {'per-op', 'split'}:
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    if not ri.type_consistent or ri.type_oneside:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            # Recursive nests need forward declarations of their policies
            has_recursive_nests = False
            cw.p('/* Policies */')
            for struct in parsed.pure_nested_structs.values():
                if struct.recursive:
                    put_typol_fwd(cw, struct)
                    has_recursive_nests = True
            if has_recursive_nests:
                cw.nl()
            for name in parsed.pure_nested_structs:
                struct = Struct(parsed, name)
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            if has_recursive_nests:
                for attr_set, struct in parsed.pure_nested_structs.items():
                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                    free_rsp_nested_prototype(ri)
                    if struct.request:
                        put_req_nested_prototype(ri, struct)
                    if struct.reply:
                        parse_rsp_nested_prototype(ri, struct)
                cw.nl()
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent or ri.type_oneside:
                        parse_rsp_msg(ri, deref=True)
                    print_req_free(ri)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
3092
3093
if __name__ == "__main__":
    # Run the generator only when executed as a script, not when imported
    main()
3096