xref: /linux/tools/net/ynl/ynl-gen-c.py (revision a634dda26186cf9a51567020fcce52bcba5e1e59)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17
18
19def c_upper(name):
20    return name.upper().replace('-', '_')
21
22
23def c_lower(name):
24    return name.lower().replace('-', '_')
25
26
27def limit_to_number(name):
28    """
29    Turn a string limit like u32-max or s64-min into its numerical value
30    """
31    if name[0] == 'u' and name.endswith('-min'):
32        return 0
33    width = int(name[1:-4])
34    if name[0] == 's':
35        width -= 1
36    value = (1 << width) - 1
37    if name[0] == 's' and name.endswith('-min'):
38        value = -value - 1
39    return value
40
41
class BaseNlLib:
    """Supplies C expressions for library-specific values in generated code."""

    def get_family_id(self):
        """Return the C expression that evaluates to the family id."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
45
46
class Type(SpecAttr):
    """
    Code-generation wrapper around one attribute of an attribute set.

    The base class implements the shared plumbing (presence tracking, C struct
    members, kernel policy and YNL type-policy entries, put/get/free/setter
    emission). Subclasses override the underscore-prefixed hooks
    (_attr_policy, _attr_typol, _attr_get, _setter_lines, ...) and the
    presence/member helpers for their attribute type.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        # NOTE: when the spec provides 'checks' this is the spec's own dict,
        # so subclasses that add keys to self.checks also modify the spec.
        self.checks = attr.get('checks', {})

        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # Same '_' suffix rule as Struct uses for nested structs whose
            # name is also a family const (presumably to avoid a C name
            # clash -- TODO confirm).
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve(): declared here for readers, then deleted so that
        # any use before resolve() raises AttributeError.
        self.enum_name = None
        delattr(self, "enum_name")

    def get_limit(self, limit, default=None):
        """
        Return check *limit* (e.g. 'min', 'max') as an integer.

        Integers are returned unchanged; symbolic limits such as 'u32-max'
        are converted with limit_to_number(). Returns None when the check is
        absent and *default* is None.

        Raises:
            Exception: if the limit names a family constant.
        """
        value = self.checks.get(limit, default)
        if value is None or isinstance(value, int):
            return value
        if value in self.family.consts:
            raise Exception("Resolving family constants not implemented, yet")
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """
        Return check *limit* as C source text.

        Integers get *suffix* appended; family constants become the
        upper-cased '<family>-<const>' macro; other symbolic limits are
        upper-cased. Returns '' when the check is absent and *default* is
        None.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute enum_name, the C enumerator naming this attribute."""
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

    def is_multi_val(self):
        """Whether the member is an array with an n_<name> count (falsy here)."""
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        """True for recursive attributes when rendering outside an operation."""
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """
        How presence is tracked: 'bit' (a 1-bit _present field), 'len'
        (a _present.<name>_len field) or, in subclasses, 'count' / ''.
        """
        return 'bit'

    def presence_member(self, space, type_filter):
        """
        Return the _present struct member declaration for this attribute.

        Returns None unless presence_type() equals *type_filter* (or for
        presence types without a declaration). *space* == 'user' selects the
        '__u32' spelling.
        """
        if self.presence_type() != type_filter:
            return

        if self.presence_type() == 'bit':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() == 'len':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}_len;"

    def _complex_member_type(self, ri):
        """C type for struct-typed members, or None for simple members."""
        return None

    def free_needs_iter(self):
        """Whether free() emits code using a loop variable 'i'."""
        return False

    def free(self, ri, var, ref):
        """Emit C freeing this attribute's heap storage in *var*->*ref*."""
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """
        Return the C parameter declarations used to pass this attribute.

        Raises:
            Exception: if the type has no complex member type and the
                subclass does not override this method.
        """
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """
        Emit this attribute's member(s) of the generated C struct.

        Multi-value attributes are preceded by an n_<name> count. Complex
        types become one member (a pointer when multi-value or recursive);
        otherwise the arg_member() declarations are reused.
        """
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Render the kernel nla_policy initializer for *policy*."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attribute's entry of the kernel policy array."""
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the YNL type-policy fields; must be overridden."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attribute's entry of the YNL type-policy array."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        """
        Emit *line* guarded by the attribute's presence field ('bit' or
        'len'); other presence types emit it unguarded.
        """
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a presence-guarded ynl_attr_put_<put_type>() call."""
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        """Emit C adding this attribute to a message; must be overridden."""
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """
        Return (lines, init_lines, local_vars) for parsing; each may be a
        string, a list of strings, or None. Must be overridden.
        """
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Emit the 'if' / 'else if' branch of the parse loop for this attribute.

        Single-value attributes are validated with ynl_attr_validate() and,
        for 'bit' presence, marked present before the type-specific lines.
        Returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        """Return the C body lines of the setter; must be overridden."""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """
        Emit a static inline C setter for this attribute.

        *ref* is the path of enclosing nest members; every enclosing level
        gets its presence bit set as well.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        for i, name in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{name}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == len(ref) - 1 and self.presence_type() != 'bit':
                continue
            code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        # 'alloc(' also matches malloc()/calloc()/realloc().
        free = any('free(' in line for line in code)
        alloc = any('alloc(' in line for line in code)
        # Setters that free the old value but allocate nothing store the
        # caller's pointer (see TypeMultiAttr); they get a '__' prefix,
        # presumably to flag the ownership transfer -- confirm with callers.
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
261
262
class TypeUnused(Type):
    """
    Attribute slot marked 'unused' in the spec.

    Its YNL type-policy entry is YNL_PT_REJECT; no kernel policy entry,
    struct member, put, parse branch or setter is generated for it.
    """

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def presence_type(self):
        # No presence tracking
        return ''

    def arg_member(self, ri):
        return []

    def _attr_get(self, ri, var):
        reject = ['return YNL_PARSE_CB_ERROR;']
        return reject, None, None

    def attr_get(self, ri, var, first):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_policy(self, cw):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
286        pass
287
288
289class TypePad(Type):
290    def presence_type(self):
291        return ''
292
293    def arg_member(self, ri):
294        return []
295
296    def _attr_typol(self):
297        return '.type = YNL_PT_IGNORE, '
298
299    def attr_put(self, ri, var):
300        pass
301
302    def attr_get(self, ri, var, first):
303        pass
304
305    def attr_policy(self, cw):
306        pass
307
308    def setter(self, ri, space, direction, deref=False, ref=None):
309        pass
310
311
312class TypeScalar(Type):
313    def __init__(self, family, attr_set, attr, value):
314        super().__init__(family, attr_set, attr, value)
315
316        self.byte_order_comment = ''
317        if 'byte-order' in attr:
318            self.byte_order_comment = f" /* {attr['byte-order']} */"
319
320        if 'enum' in self.attr:
321            enum = self.family.consts[self.attr['enum']]
322            low, high = enum.value_range()
323            if 'min' not in self.checks:
324                if low != 0 or self.type[0] == 's':
325                    self.checks['min'] = low
326            if 'max' not in self.checks:
327                self.checks['max'] = high
328
329        if 'min' in self.checks and 'max' in self.checks:
330            if self.get_limit('min') > self.get_limit('max'):
331                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
332            self.checks['range'] = True
333
334        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
335        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
336        if low < 0 and self.type[0] == 'u':
337            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
338        if low < -32768 or high > 32767:
339            self.checks['full-range'] = True
340
341        # Added by resolve():
342        self.is_bitfield = None
343        delattr(self, "is_bitfield")
344        self.type_name = None
345        delattr(self, "type_name")
346
347    def resolve(self):
348        self.resolve_up(super())
349
350        if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']:
351            self.is_bitfield = True
352        elif 'enum' in self.attr:
353            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
354        else:
355            self.is_bitfield = False
356
357        if not self.is_bitfield and 'enum' in self.attr:
358            self.type_name = self.family.consts[self.attr['enum']].user_type
359        elif self.is_auto_scalar:
360            self.type_name = '__' + self.type[0] + '64'
361        else:
362            self.type_name = '__' + self.type
363
364    def _attr_policy(self, policy):
365        if 'flags-mask' in self.checks or self.is_bitfield:
366            if self.is_bitfield:
367                enum = self.family.consts[self.attr['enum']]
368                mask = enum.get_mask(as_flags=True)
369            else:
370                flags = self.family.consts[self.checks['flags-mask']]
371                flag_cnt = len(flags['entries'])
372                mask = (1 << flag_cnt) - 1
373            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
374        elif 'full-range' in self.checks:
375            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
376        elif 'range' in self.checks:
377            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
378        elif 'min' in self.checks:
379            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
380        elif 'max' in self.checks:
381            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
382        return super()._attr_policy(policy)
383
384    def _attr_typol(self):
385        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '
386
387    def arg_member(self, ri):
388        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']
389
390    def attr_put(self, ri, var):
391        self._attr_put_simple(ri, var, self.type)
392
393    def _attr_get(self, ri, var):
394        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None
395
396    def _setter_lines(self, ri, member, presence):
397        return [f"{member} = {self.c_name};"]
398
399
400class TypeFlag(Type):
401    def arg_member(self, ri):
402        return []
403
404    def _attr_typol(self):
405        return '.type = YNL_PT_FLAG, '
406
407    def attr_put(self, ri, var):
408        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)")
409
410    def _attr_get(self, ri, var):
411        return [], None, None
412
413    def _setter_lines(self, ri, member, presence):
414        return []
415
416
417class TypeString(Type):
418    def arg_member(self, ri):
419        return [f"const char *{self.c_name}"]
420
421    def presence_type(self):
422        return 'len'
423
424    def struct_member(self, ri):
425        ri.cw.p(f"char *{self.c_name};")
426
427    def _attr_typol(self):
428        return f'.type = YNL_PT_NUL_STR, '
429
430    def _attr_policy(self, policy):
431        if 'exact-len' in self.checks:
432            mem = 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
433        else:
434            mem = '{ .type = ' + policy
435            if 'max-len' in self.checks:
436                mem += ', .len = ' + self.get_limit_str('max-len')
437            mem += ', }'
438        return mem
439
440    def attr_policy(self, cw):
441        if self.checks.get('unterminated-ok', False):
442            policy = 'NLA_STRING'
443        else:
444            policy = 'NLA_NUL_STRING'
445
446        spec = self._attr_policy(policy)
447        cw.p(f"\t[{self.enum_name}] = {spec},")
448
449    def attr_put(self, ri, var):
450        self._attr_put_simple(ri, var, 'str')
451
452    def _attr_get(self, ri, var):
453        len_mem = var + '->_present.' + self.c_name + '_len'
454        return [f"{len_mem} = len;",
455                f"{var}->{self.c_name} = malloc(len + 1);",
456                f"memcpy({var}->{self.c_name}, ynl_attr_get_str(attr), len);",
457                f"{var}->{self.c_name}[len] = 0;"], \
458               ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));'], \
459               ['unsigned int len;']
460
461    def _setter_lines(self, ri, member, presence):
462        return [f"free({member});",
463                f"{presence}_len = strlen({self.c_name});",
464                f"{member} = malloc({presence}_len + 1);",
465                f'memcpy({member}, {self.c_name}, {presence}_len);',
466                f'{member}[{presence}_len] = 0;']
467
468
469class TypeBinary(Type):
470    def arg_member(self, ri):
471        return [f"const void *{self.c_name}", 'size_t len']
472
473    def presence_type(self):
474        return 'len'
475
476    def struct_member(self, ri):
477        ri.cw.p(f"void *{self.c_name};")
478
479    def _attr_typol(self):
480        return f'.type = YNL_PT_BINARY,'
481
482    def _attr_policy(self, policy):
483        if len(self.checks) == 0:
484            pass
485        elif len(self.checks) == 1:
486            check_name = list(self.checks)[0]
487            if check_name not in {'exact-len', 'min-len', 'max-len'}:
488                raise Exception('Unsupported check for binary type: ' + check_name)
489        else:
490            raise Exception('More than one check for binary type not implemented, yet')
491
492        if len(self.checks) == 0:
493            mem = '{ .type = NLA_BINARY, }'
494        elif 'exact-len' in self.checks:
495            mem = 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
496        elif 'min-len' in self.checks:
497            mem = '{ .len = ' + self.get_limit_str('min-len') + ', }'
498        elif 'max-len' in self.checks:
499            mem = 'NLA_POLICY_MAX_LEN(' + self.get_limit_str('max-len') + ')'
500
501        return mem
502
503    def attr_put(self, ri, var):
504        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, " +
505                            f"{var}->{self.c_name}, {var}->_present.{self.c_name}_len)")
506
507    def _attr_get(self, ri, var):
508        len_mem = var + '->_present.' + self.c_name + '_len'
509        return [f"{len_mem} = len;",
510                f"{var}->{self.c_name} = malloc(len);",
511                f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);"], \
512               ['len = ynl_attr_data_len(attr);'], \
513               ['unsigned int len;']
514
515    def _setter_lines(self, ri, member, presence):
516        return [f"free({member});",
517                f"{presence}_len = len;",
518                f"{member} = malloc({presence}_len);",
519                f'memcpy({member}, {self.c_name}, {presence}_len);']
520
521
522class TypeBitfield32(Type):
523    def _complex_member_type(self, ri):
524        return "struct nla_bitfield32"
525
526    def _attr_typol(self):
527        return f'.type = YNL_PT_BITFIELD32, '
528
529    def _attr_policy(self, policy):
530        if not 'enum' in self.attr:
531            raise Exception('Enum required for bitfield32 attr')
532        enum = self.family.consts[self.attr['enum']]
533        mask = enum.get_mask(as_flags=True)
534        return f"NLA_POLICY_BITFIELD32({mask})"
535
536    def attr_put(self, ri, var):
537        line = f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, sizeof(struct nla_bitfield32))"
538        self._attr_put_line(ri, var, line)
539
540    def _attr_get(self, ri, var):
541        return f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));", None, None
542
543    def _setter_lines(self, ri, member, presence):
544        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
545
546
547class TypeNest(Type):
548    def is_recursive(self):
549        return self.family.pure_nested_structs[self.nested_attrs].recursive
550
551    def _complex_member_type(self, ri):
552        return self.nested_struct_type
553
554    def free(self, ri, var, ref):
555        at = '&'
556        if self.is_recursive_for_op(ri):
557            at = ''
558            ri.cw.p(f'if ({var}->{ref}{self.c_name})')
559        ri.cw.p(f'{self.nested_render_name}_free({at}{var}->{ref}{self.c_name});')
560
561    def _attr_typol(self):
562        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
563
564    def _attr_policy(self, policy):
565        return 'NLA_POLICY_NESTED(' + self.nested_render_name + '_nl_policy)'
566
567    def attr_put(self, ri, var):
568        at = '' if self.is_recursive_for_op(ri) else '&'
569        self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
570                            f"{self.enum_name}, {at}{var}->{self.c_name})")
571
572    def _attr_get(self, ri, var):
573        get_lines = [f"if ({self.nested_render_name}_parse(&parg, attr))",
574                     "return YNL_PARSE_CB_ERROR;"]
575        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
576                      f"parg.data = &{var}->{self.c_name};"]
577        return get_lines, init_lines, None
578
579    def setter(self, ri, space, direction, deref=False, ref=None):
580        ref = (ref if ref else []) + [self.c_name]
581
582        for _, attr in ri.family.pure_nested_structs[self.nested_attrs].member_list():
583            if attr.is_recursive():
584                continue
585            attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=ref)
586
587
588class TypeMultiAttr(Type):
589    def __init__(self, family, attr_set, attr, value, base_type):
590        super().__init__(family, attr_set, attr, value)
591
592        self.base_type = base_type
593
594    def is_multi_val(self):
595        return True
596
597    def presence_type(self):
598        return 'count'
599
600    def _complex_member_type(self, ri):
601        if 'type' not in self.attr or self.attr['type'] == 'nest':
602            return self.nested_struct_type
603        elif self.attr['type'] in scalars:
604            scalar_pfx = '__' if ri.ku_space == 'user' else ''
605            return scalar_pfx + self.attr['type']
606        else:
607            raise Exception(f"Sub-type {self.attr['type']} not supported yet")
608
609    def free_needs_iter(self):
610        return 'type' not in self.attr or self.attr['type'] == 'nest'
611
612    def free(self, ri, var, ref):
613        if self.attr['type'] in scalars:
614            ri.cw.p(f"free({var}->{ref}{self.c_name});")
615        elif 'type' not in self.attr or self.attr['type'] == 'nest':
616            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
617            ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
618            ri.cw.p(f"free({var}->{ref}{self.c_name});")
619        else:
620            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")
621
622    def _attr_policy(self, policy):
623        return self.base_type._attr_policy(policy)
624
625    def _attr_typol(self):
626        return self.base_type._attr_typol()
627
628    def _attr_get(self, ri, var):
629        return f'n_{self.c_name}++;', None, None
630
631    def attr_put(self, ri, var):
632        if self.attr['type'] in scalars:
633            put_type = self.type
634            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
635            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
636        elif 'type' not in self.attr or self.attr['type'] == 'nest':
637            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
638            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
639                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
640        else:
641            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")
642
643    def _setter_lines(self, ri, member, presence):
644        # For multi-attr we have a count, not presence, hack up the presence
645        presence = presence[:-(len('_present.') + len(self.c_name))] + "n_" + self.c_name
646        return [f"free({member});",
647                f"{member} = {self.c_name};",
648                f"{presence} = n_{self.c_name};"]
649
650
651class TypeArrayNest(Type):
652    def is_multi_val(self):
653        return True
654
655    def presence_type(self):
656        return 'count'
657
658    def _complex_member_type(self, ri):
659        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
660            return self.nested_struct_type
661        elif self.attr['sub-type'] in scalars:
662            scalar_pfx = '__' if ri.ku_space == 'user' else ''
663            return scalar_pfx + self.attr['sub-type']
664        else:
665            raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")
666
667    def _attr_typol(self):
668        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
669
670    def _attr_get(self, ri, var):
671        local_vars = ['const struct nlattr *attr2;']
672        get_lines = [f'attr_{self.c_name} = attr;',
673                     'ynl_attr_for_each_nested(attr2, attr)',
674                     f'\t{var}->n_{self.c_name}++;']
675        return get_lines, None, local_vars
676
677
678class TypeNestTypeValue(Type):
679    def _complex_member_type(self, ri):
680        return self.nested_struct_type
681
682    def _attr_typol(self):
683        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
684
685    def _attr_get(self, ri, var):
686        prev = 'attr'
687        tv_args = ''
688        get_lines = []
689        local_vars = []
690        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
691                      f"parg.data = &{var}->{self.c_name};"]
692        if 'type-value' in self.attr:
693            tv_names = [c_lower(x) for x in self.attr["type-value"]]
694            local_vars += [f'const struct nlattr *attr_{", *attr_".join(tv_names)};']
695            local_vars += [f'__u32 {", ".join(tv_names)};']
696            for level in self.attr["type-value"]:
697                level = c_lower(level)
698                get_lines += [f'attr_{level} = ynl_attr_data({prev});']
699                get_lines += [f'{level} = ynl_attr_type(attr_{level});']
700                prev = 'attr_' + level
701
702            tv_args = f", {', '.join(tv_names)}"
703
704        get_lines += [f"{self.nested_render_name}_parse(&parg, {prev}{tv_args});"]
705        return get_lines, init_lines, local_vars
706
707
708class Struct:
709    def __init__(self, family, space_name, type_list=None, inherited=None):
710        self.family = family
711        self.space_name = space_name
712        self.attr_set = family.attr_sets[space_name]
713        # Use list to catch comparisons with empty sets
714        self._inherited = inherited if inherited is not None else []
715        self.inherited = []
716
717        self.nested = type_list is None
718        if family.name == c_lower(space_name):
719            self.render_name = c_lower(family.ident_name)
720        else:
721            self.render_name = c_lower(family.ident_name + '-' + space_name)
722        self.struct_name = 'struct ' + self.render_name
723        if self.nested and space_name in family.consts:
724            self.struct_name += '_'
725        self.ptr_name = self.struct_name + ' *'
726        # All attr sets this one contains, directly or multiple levels down
727        self.child_nests = set()
728
729        self.request = False
730        self.reply = False
731        self.recursive = False
732
733        self.attr_list = []
734        self.attrs = dict()
735        if type_list is not None:
736            for t in type_list:
737                self.attr_list.append((t, self.attr_set[t]),)
738        else:
739            for t in self.attr_set:
740                self.attr_list.append((t, self.attr_set[t]),)
741
742        max_val = 0
743        self.attr_max_val = None
744        for name, attr in self.attr_list:
745            if attr.value >= max_val:
746                max_val = attr.value
747                self.attr_max_val = attr
748            self.attrs[name] = attr
749
750    def __iter__(self):
751        yield from self.attrs
752
753    def __getitem__(self, key):
754        return self.attrs[key]
755
756    def member_list(self):
757        return self.attr_list
758
759    def set_inherited(self, new_inherited):
760        if self._inherited != new_inherited:
761            raise Exception("Inheriting different members not supported")
762        self.inherited = [c_lower(x) for x in sorted(self._inherited)]
763
764
765class EnumEntry(SpecEnumEntry):
766    def __init__(self, enum_set, yaml, prev, value_start):
767        super().__init__(enum_set, yaml, prev, value_start)
768
769        if prev:
770            self.value_change = (self.value != prev.value + 1)
771        else:
772            self.value_change = (self.value != 0)
773        self.value_change = self.value_change or self.enum_set['type'] == 'flags'
774
775        # Added by resolve:
776        self.c_name = None
777        delattr(self, "c_name")
778
779    def resolve(self):
780        self.resolve_up(super())
781
782        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
783
784
785class EnumSet(SpecEnumSet):
786    def __init__(self, family, yaml):
787        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])
788
789        if 'enum-name' in yaml:
790            if yaml['enum-name']:
791                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
792                self.user_type = self.enum_name
793            else:
794                self.enum_name = None
795        else:
796            self.enum_name = 'enum ' + self.render_name
797
798        if self.enum_name:
799            self.user_type = self.enum_name
800        else:
801            self.user_type = 'int'
802
803        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
804
805        super().__init__(family, yaml)
806
807    def new_entry(self, entry, prev_entry, value_start):
808        return EnumEntry(self, entry, prev_entry, value_start)
809
810    def value_range(self):
811        low = min([x.value for x in self.entries.values()])
812        high = max([x.value for x in self.entries.values()])
813
814        if high - low + 1 != len(self.entries):
815            raise Exception("Can't get value range for a noncontiguous enum")
816
817        return low, high
818
819
820class AttrSet(SpecAttrSet):
821    def __init__(self, family, yaml):
822        super().__init__(family, yaml)
823
824        if self.subset_of is None:
825            if 'name-prefix' in yaml:
826                pfx = yaml['name-prefix']
827            elif self.name == family.name:
828                pfx = family.ident_name + '-a-'
829            else:
830                pfx = f"{family.ident_name}-a-{self.name}-"
831            self.name_prefix = c_upper(pfx)
832            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
833            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))
834        else:
835            self.name_prefix = family.attr_sets[self.subset_of].name_prefix
836            self.max_name = family.attr_sets[self.subset_of].max_name
837            self.cnt_name = family.attr_sets[self.subset_of].cnt_name
838
839        # Added by resolve:
840        self.c_name = None
841        delattr(self, "c_name")
842
843    def resolve(self):
844        self.c_name = c_lower(self.name)
845        if self.c_name in _C_KW:
846            self.c_name += '_'
847        if self.c_name == self.family.c_name:
848            self.c_name = ''
849
850    def new_attr(self, elem, value):
851        if elem['type'] in scalars:
852            t = TypeScalar(self.family, self, elem, value)
853        elif elem['type'] == 'unused':
854            t = TypeUnused(self.family, self, elem, value)
855        elif elem['type'] == 'pad':
856            t = TypePad(self.family, self, elem, value)
857        elif elem['type'] == 'flag':
858            t = TypeFlag(self.family, self, elem, value)
859        elif elem['type'] == 'string':
860            t = TypeString(self.family, self, elem, value)
861        elif elem['type'] == 'binary':
862            t = TypeBinary(self.family, self, elem, value)
863        elif elem['type'] == 'bitfield32':
864            t = TypeBitfield32(self.family, self, elem, value)
865        elif elem['type'] == 'nest':
866            t = TypeNest(self.family, self, elem, value)
867        elif elem['type'] == 'indexed-array' and 'sub-type' in elem:
868            if elem["sub-type"] == 'nest':
869                t = TypeArrayNest(self.family, self, elem, value)
870            else:
871                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
872        elif elem['type'] == 'nest-type-value':
873            t = TypeNestTypeValue(self.family, self, elem, value)
874        else:
875            raise Exception(f"No typed class for type {elem['type']}")
876
877        if 'multi-attr' in elem and elem['multi-attr']:
878            t = TypeMultiAttr(self.family, self, elem, value, t)
879
880        return t
881
882
class Operation(SpecOperation):
    """Operation annotated with the C names the code generator emits for it."""

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # True only when both 'do' and 'dump' declare a request
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        # Set by mark_has_ntf() when another op notifies via this one
        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Resolve the base spec, then derive the C enum name of the command."""
        self.resolve_up(super())

        family = self.family
        prefix = family.async_op_prefix if self.is_async else family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that some op names this one as its 'notify' target."""
        self.has_ntf = True
908
909
class Family(SpecFamily):
    """Netlink family spec extended with the state the C code generator needs.

    resolve() computes C naming prefixes, which attribute sets are message
    roots and which are "pure" nested structs (and in which direction each
    is used), the do/dump pre/post hooks and, optionally, a global policy.
    """

    def __init__(self, file_name, exclude_ops):
        """Load the spec from *file_name*, skipping *exclude_ops*.

        Also derives the C family-name/version macro names and the uAPI
        header the generated code refers to.
        """
        # Added by resolve:
        # (Each attribute is declared then deleted, presumably so the names
        # are visible to readers while any access before they are really
        # assigned raises AttributeError.)
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # "linux/<name>.h" -> "<name>"; any other header path falls back to ident_name
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """Resolve the base spec, then compute all code-generation state.

        Raises:
            Exception: if the protocol is not one of the genetlink flavors.
        """
        self.resolve_up(super())

        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
            raise Exception("Codegen only supported for genetlink")

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] -> {'set': names seen, 'list': names in first-seen order}
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> set('request', 'reply')
        self.pure_nested_structs = dict()

        # Order matters: events must be mocked up before the root sets are
        # collected, and nested sets build on the root sets.
        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        """Factory hook: build the generator's EnumSet."""
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        """Factory hook: build the generator's AttrSet."""
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        """Factory hook: build the generator's Operation."""
        return Operation(self, elem, req_value, rsp_value)

    def _mark_notify(self):
        """Flag every op that another message names as its 'notify' target."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        # NOTE: mutates the raw yaml; any existing 'do' of an event is replaced
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Collect, per attribute set used directly by messages, the
        attributes used in requests and in replies (events count as replies).
        """
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so each struct comes after the
        non-recursive structs it nests (i.e. definition order for C).
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                    if self.pure_nested_structs[nested].recursive:
                        continue
                    if nested not in pns_key_seen:
                        # Dicts keep insertion order; pop + re-insert moves struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                # Retry this set after the others have been placed
                pns_key_list.append(name)

    def _load_nested_sets(self):
        """Find all attribute sets reachable through 'nested-attributes' from
        the root sets, create a Struct for each, and work out their inherited
        members, request/reply use, child nests and recursion.

        Raises:
            Exception: if an attribute set is used both as a root and nested.
        """
        # Breadth-first walk from the root sets
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                if 'type-value' in spec:
                    # NOTE(review): unreachable - the root case already raised above
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'indexed-array':
                    # Elements of an indexed array also receive their index
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        # Directions of nests referenced directly from root-set attributes
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        # (reversed: after sorting, parents come after children, so walking
        # backwards visits parents first and flags flow down to children)
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                    struct.child_nests.add(child_name)
                    child = self.pure_nested_structs.get(child_name)
                    if child:
                        if not child.recursive:
                            struct.child_nests.update(child.child_nests)
                        child.request |= struct.request
                        child.reply |= struct.reply
                if attr_set in struct.child_nests:
                    struct.recursive = True

        # Re-sort now that recursive structs are known
        self._sort_pure_types()

    def _load_attr_use(self):
        """Mark each attribute's .request / .reply based on where it is used."""
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.request = True
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.reply = True

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.request = True
                if attr in rs_members['reply']:
                    spec.reply = True

    def _load_global_policy(self):
        """Build one policy covering the request attributes of every op.

        Sets global_policy (attribute names, in attribute-set order) and
        global_policy_set.

        Raises:
            Exception: if ops with an attribute set do not all share it.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        # NOTE(review): raises KeyError if no op has an 'attribute-set'
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect unique pre/post hook names per op mode, keeping first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1171
1172
class RenderInfo:
    """Context for rendering one op (or one bare attribute set) in one mode.

    Holds the writer, the family, the op/op_mode, the attribute set, the C
    type-name stem and the request/reply Struct layouts.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.cw = cw
        self.nl = cw.nlib
        self.family = family
        self.ku_space = ku_space
        self.op = op
        self.op_mode = op_mode

        if op and op.fixed_header:
            self.fixed_hdr = 'struct ' + c_lower(op.fixed_header)
        else:
            self.fixed_hdr = None

        self.type_consistent = self._dump_reply_matches_do(op, op_mode)

        # Without an explicit attribute set use the op's own one
        self.attr_set = attr_set or op['attribute-set']

        if op:
            self.type_name = c_lower(op.name)
            self.type_name_conflict = False
        else:
            self.type_name = c_lower(attr_set)
            # The set's name is also used by one of the family's definitions
            self.type_name_conflict = attr_set in family.consts

        # Notifications are laid out like the 'do' messages
        msg_mode = 'do' if op_mode == 'notify' else op_mode
        self.struct = {}
        if op:
            for op_dir in ('request', 'reply'):
                mode_spec = op[msg_mode]
                attrs = mode_spec[op_dir]['attributes'] if op_dir in mode_spec else []
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=attrs)
        if msg_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])

    @staticmethod
    def _dump_reply_matches_do(op, op_mode):
        """Tell whether 'do' and 'dump' response parsing is identical.

        Always True when rendering 'do' or when the op has no 'dump'.
        """
        if op_mode == 'do' or 'dump' not in op:
            return True
        if 'do' not in op:
            return False
        do_has_reply = 'reply' in op['do']
        if do_has_reply != ('reply' in op["dump"]):
            return False
        return not do_has_reply or op["do"]["reply"] == op["dump"]["reply"]
1221
1222
class CodeWriter:
    """Indentation-aware writer for generated C code.

    Output goes to stdout, or, when *out_file* is given, to a temporary file
    that close_out_file() copies over *out_file*.  With overwrite=False an
    existing *out_file* whose contents are identical is left untouched.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # a blank line is pending before the next line
        self._block_end = False     # a '}' is pending (may merge with 'else')
        self._silent_block = False  # previous line was a brace-less if/while/for
        self._ind = 0
        self._ifdef_block = None
        if out_file is None:
            self._out = os.sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """Publish the buffered output to out_file and switch back to stdout.

        The temporary buffer is closed on every path, including when the
        target is left unmodified because its contents did not change.
        """
        if self._out == os.sys.stdout:
            return
        self._out.flush()
        # Avoid modifying the file if contents didn't change
        unchanged = (not self._overwrite and os.path.isfile(self._out_file) and
                     filecmp.cmp(self._out.name, self._out_file, shallow=False))
        if not unchanged:
            with open(self._out_file, 'w+') as out_file:
                self._out.seek(0)
                shutil.copyfileobj(self._out, out_file)
        self._out.close()
        self._out = os.sys.stdout

    @classmethod
    def _is_cond(cls, line):
        return line.startswith('if') or line.startswith('while') or line.startswith('for')

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        Lines ending in ':' (labels, case) are outdented by one level; the
        statement after a brace-less if/while/for is indented by one extra
        level; preprocessor lines ('#...') always start at column 0.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line[-1] == ':':
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if line[0] == '#':
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next written line."""
        self._nl = True

    def block_start(self, line=''):
        """Write '<line> {' and indent."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Dedent and close the block; *line* (e.g. ';') follows the '}'."""
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write *doc* as ' *' comment lines wrapped before 79 columns."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype, wrapping the arguments past 80 columns."""
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        # Longer return types go on their own line
        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        # Continuation lines align with the opening parenthesis (tabs + spaces)
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        *local_vars* may be a single string or a list; the caller's list is
        not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a complete function: prototype, locals and body lines."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.write_func_lvar(local_vars=local_vars)

        self.block_start()
        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write (name, value) pairs as tab-aligned #define lines.

        int values are written bare, str values quoted.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write (field, value) pairs as tab-aligned designated initializers."""
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch the surrounding '#ifdef CONFIG_<config>' (None closes it)."""
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1418
1419
1420scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64', 'uint', 'sint'}
1421
1422direction_to_suffix = {
1423    'reply': '_rsp',
1424    'request': '_req',
1425    '': ''
1426}
1427
1428op_mode_to_wrapper = {
1429    'do': '',
1430    'dump': '_list',
1431    'notify': '_ntf',
1432    'event': '',
1433}
1434
1435_C_KW = {
1436    'auto',
1437    'bool',
1438    'break',
1439    'case',
1440    'char',
1441    'const',
1442    'continue',
1443    'default',
1444    'do',
1445    'double',
1446    'else',
1447    'enum',
1448    'extern',
1449    'float',
1450    'for',
1451    'goto',
1452    'if',
1453    'inline',
1454    'int',
1455    'long',
1456    'register',
1457    'return',
1458    'short',
1459    'signed',
1460    'sizeof',
1461    'static',
1462    'struct',
1463    'switch',
1464    'typedef',
1465    'union',
1466    'unsigned',
1467    'void',
1468    'volatile',
1469    'while'
1470}
1471
1472
def rdir(direction):
    """Return the opposite message direction ('request' <-> 'reply').

    Any other value is returned unchanged.
    """
    if direction not in ('reply', 'request'):
        return direction
    return 'request' if direction == 'reply' else 'reply'
1479
1480
def op_prefix(ri, direction, deref=False):
    """Return the C type-name stem for ri's op in *direction*.

    'do' (or no mode) uses the plain direction suffix.  Other modes use
    '_req_dump' for requests.  Replies use the shared per-mode wrapper name
    when do/dump replies match (the direction suffix when *deref*), and
    '_rsp_dump' / '_rsp_list' otherwise.
    """
    if not ri.op_mode or ri.op_mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[ri.op_mode]

    return f"{ri.family.c_name}_{ri.type_name}{tail}"
1500
1501
def type_name(ri, direction, deref=False):
    """Return the full C struct type ('struct <prefix>') for ri in *direction*."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1504
1505
def print_prototype(ri, direction, terminate=True, doc=None):
    """Write the prototype of the user-space function for ri.op in ri.op_mode.

    The function takes the ynl socket plus the request struct when the mode
    has a request, and returns a pointer to the reply struct when it has a
    reply ('int' otherwise).  With *terminate* the prototype ends in ';'.
    """
    mode_spec = ri.op[ri.op_mode]
    fname = ri.op.render_name
    if ri.op_mode == 'dump':
        fname += '_dump'

    params = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        req_type = type_name(ri, direction)
        # Parameter name is the direction suffix without its leading '_'
        params.append(f"{req_type} *{direction_to_suffix[direction][1:]}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, params, doc=doc,
                          suffix=';' if terminate else '')
1522
1523
def print_req_prototype(ri):
    """Declare the request function, documented with the op's 'doc' text."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1526
1527
def print_dump_prototype(ri):
    """Declare the dump function (no doc comment)."""
    print_prototype(ri, direction="request")
1530
1531
def put_typol_fwd(cw, struct):
    """Declare the extern ynl_policy_nest object of *struct*."""
    decl = 'extern const struct ynl_policy_nest ' + struct.render_name + '_nest;'
    cw.p(decl)
1534
1535
def put_typol(cw, struct):
    """Emit the attribute policy table of *struct* and the ynl_policy_nest wrapping it.

    The table is sized by the attribute set's max attribute + 1; each member
    writes its own entry via attr_typol().
    """
    name = struct.render_name
    max_attr = struct.attr_set.max_name

    cw.block_start(line=f'const struct ynl_policy_attr {name}_policy[{max_attr} + 1] =')
    for member in (arg for _, arg in struct.member_list()):
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {name}_nest =')
    for init in (f'.max_attr = {max_attr},', f'.table = {name}_policy,'):
        cw.p(init)
    cw.block_end(line=';')
    cw.nl()
1551
1552
1553def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
1554    args = [f'int {arg_name}']
1555    if enum:
1556        args = [enum.user_type + ' ' + arg_name]
1557    cw.write_func_prot('const char *', f'{render_name}_str', args)
1558    cw.block_start()
1559    if enum and enum.type == 'flags':
1560        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
1561    cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))')
1562    cw.p('return NULL;')
1563    cw.p(f'return {map_name}[{arg_name}];')
1564    cw.block_end()
1565    cw.nl()
1566
1567
def put_op_name_fwd(family, cw):
    """Declare '<family>_op_str(int op)'."""
    fname = family.c_name + '_op_str'
    cw.write_func_prot('const char *', fname, ['int op'], suffix=';')
1570
1571
def put_op_name(family, cw):
    """Emit the reply-value -> op-name string table and '<family>_op_str()'.

    Entries are keyed by the op's enum name when its request and reply
    values match, by the raw reply value otherwise.  Messages without a
    reply value are left out.
    """
    strmap = f'{family.c_name}_op_strmap'

    cw.block_start(line=f"static const char * const {strmap}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', strmap, 'op')
1591
1592
def put_enum_to_str_fwd(family, cw, enum):
    """Declare '<enum>_str(<user_type> value)'."""
    cw.write_func_prot('const char *', enum.render_name + '_str',
                       [f'{enum.user_type} value'], suffix=';')
1596
1597
def put_enum_to_str(family, cw, enum):
    """Emit the value -> name string table of *enum* and its '_str()' lookup."""
    strmap = enum.render_name + '_strmap'
    rows = [f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values()]

    cw.block_start(line=f"static const char * const {strmap}[] =")
    for row in rows:
        cw.p(row)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, strmap, 'value', enum=enum)
1607
1608
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Write the prototype of '<struct>_put()', which adds *struct* as a nest.

    With the default ';' suffix this is a declaration; pass '' to start the
    definition.
    """
    params = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]
    fname = f'{struct.render_name}_put'
    ri.cw.write_func_prot('int', fname, params, suffix=suffix)
1616
1617
def put_req_nested(ri, struct):
    """Emit the definition of '<struct>_put()'.

    The function opens a nest of type attr_type, lets every member write
    itself from 'obj', closes the nest and returns 0.
    """
    put_req_nested_prototype(ri, struct, suffix='')
    cw = ri.cw
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = ynl_attr_nest_start(nlh, attr_type);")
    for member in (arg for _, arg in struct.member_list()):
        member.attr_put(ri, "obj")
    cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1634
1635
1636def _multi_parse(ri, struct, init_lines, local_vars):
1637    if struct.nested:
1638        iter_line = "ynl_attr_for_each_nested(attr, nested)"
1639    else:
1640        if ri.fixed_hdr:
1641            local_vars += ['void *hdr;']
1642        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
1643
1644    array_nests = set()
1645    multi_attrs = set()
1646    needs_parg = False
1647    for arg, aspec in struct.member_list():
1648        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
1649            if aspec["sub-type"] == 'nest':
1650                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
1651                array_nests.add(arg)
1652            else:
1653                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
1654        if 'multi-attr' in aspec:
1655            multi_attrs.add(arg)
1656        needs_parg |= 'nested-attributes' in aspec
1657    if array_nests or multi_attrs:
1658        local_vars.append('int i;')
1659    if needs_parg:
1660        local_vars.append('struct ynl_parse_arg parg;')
1661        init_lines.append('parg.ys = yarg->ys;')
1662
1663    all_multi = array_nests | multi_attrs
1664
1665    for anest in sorted(all_multi):
1666        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")
1667
1668    ri.cw.block_start()
1669    ri.cw.write_func_lvar(local_vars)
1670
1671    for line in init_lines:
1672        ri.cw.p(line)
1673    ri.cw.nl()
1674
1675    for arg in struct.inherited:
1676        ri.cw.p(f'dst->{arg} = {arg};')
1677
1678    if ri.fixed_hdr:
1679        ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
1680        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
1681    for anest in sorted(all_multi):
1682        aspec = struct[anest]
1683        ri.cw.p(f"if (dst->{aspec.c_name})")
1684        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')
1685
1686    ri.cw.nl()
1687    ri.cw.block_start(line=iter_line)
1688    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
1689    ri.cw.nl()
1690
1691    first = True
1692    for _, arg in struct.member_list():
1693        good = arg.attr_get(ri, 'dst', first=first)
1694        # First may be 'unused' or 'pad', ignore those
1695        first &= not good
1696
1697    ri.cw.block_end()
1698    ri.cw.nl()
1699
1700    for anest in sorted(array_nests):
1701        aspec = struct[anest]
1702
1703        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
1704        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1705        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
1706        ri.cw.p('i = 0;')
1707        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1708        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
1709        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1710        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
1711        ri.cw.p('return YNL_PARSE_CB_ERROR;')
1712        ri.cw.p('i++;')
1713        ri.cw.block_end()
1714        ri.cw.block_end()
1715    ri.cw.nl()
1716
1717    for anest in sorted(multi_attrs):
1718        aspec = struct[anest]
1719        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
1720        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1721        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
1722        ri.cw.p('i = 0;')
1723        if 'nested-attributes' in aspec:
1724            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1725        ri.cw.block_start(line=iter_line)
1726        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
1727        if 'nested-attributes' in aspec:
1728            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1729            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
1730            ri.cw.p('return YNL_PARSE_CB_ERROR;')
1731        elif aspec.type in scalars:
1732            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
1733        else:
1734            raise Exception('Nest parsing type not supported yet')
1735        ri.cw.p('i++;')
1736        ri.cw.block_end()
1737        ri.cw.block_end()
1738        ri.cw.block_end()
1739    ri.cw.nl()
1740
1741    if struct.nested:
1742        ri.cw.p('return 0;')
1743    else:
1744        ri.cw.p('return YNL_PARSE_CB_OK;')
1745    ri.cw.block_end()
1746    ri.cw.nl()
1747
1748
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Write the prototype of the C parser for the nested struct *struct*.

    The generated parser takes the parse context and the nest attribute,
    followed by one ``__u32`` argument per value the struct inherits
    (``struct.inherited``).  *suffix* terminates the prototype (';' for a
    declaration, '' when a body follows).
    """
    params = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    params.extend('__u32 ' + inherited for inherited in struct.inherited)

    fn_name = f'{struct.render_name}_parse'
    ri.cw.write_func_prot('int', fn_name, params, suffix=suffix)
1757
1758
def parse_rsp_nested(ri, struct):
    """Write the full definition of the C parser for nested struct *struct*.

    The prototype comes from parse_rsp_nested_prototype(). The body is
    generated by _multi_parse(), and the destination object is taken from
    ``yarg->data``.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    # No extra init lines; the destination pointer is initialised in its
    # declaration.
    _multi_parse(ri, struct, [],
                 ['const struct nlattr *attr;',
                  f'{struct.ptr_name}dst = yarg->data;'])
1767
1768
def parse_rsp_msg(ri, deref=False):
    """Write the message-level parse callback for the op's reply.

    Nothing is written unless the current mode has a 'reply' or is an
    event.  If the reply struct has no members, the callback is a stub that
    returns YNL_PARSE_CB_OK.  Otherwise the body comes from _multi_parse().
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    lvars = [f'{type_name(ri, "reply", deref=deref)} *dst;',
             'const struct nlattr *attr;']

    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse',
                          ['const struct nlmsghdr *nlh',
                           'struct ynl_parse_arg *yarg'])

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply: accept the message without looking at attributes.
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply, ['dst = yarg->data;'], lvars)
1790
1791
def print_req(ri):
    """Write the C implementation of a blocking 'do' request.

    The generated function builds the request (fixed header, then request
    attributes) and sends it with ynl_exec().  If the mode has a reply, the
    function allocates a response object, registers the reply parser, and
    returns the response, or frees it and returns NULL on error.  Without
    a reply, it returns 0 on success and -1 on error.
    """
    direction = "request"
    cw = ri.cw
    has_reply = 'reply' in ri.op[ri.op_mode]

    lvars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
             'struct nlmsghdr *nlh;',
             'int err;']
    if has_reply:
        lvars.append(f'{type_name(ri, rdir(direction))} *rsp;')
    if ri.fixed_hdr:
        lvars.extend(['size_t hdr_len;', 'void *hdr;'])

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(lvars)

    cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")
    cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()

    if ri.fixed_hdr:
        # Copy the caller's fixed header in right after the genl header.
        for stmt in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            cw.p(stmt)
        cw.nl()

    for _, member in ri.struct["request"].member_list():
        member.attr_put(ri, "req")
    cw.nl()

    if has_reply:
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yrs.yarg.data = rsp;')
        cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        # Ops without their own value reply with a different command ID.
        expected_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        cw.p(f'yrs.rsp_cmd = {expected_cmd};')
        cw.nl()

    cw.p("err = ynl_exec(ys, nlh, &yrs);")
    cw.p('if (err < 0)')
    cw.p('goto err_free;' if has_reply else 'return -1;')
    cw.nl()

    cw.p(f"return {'rsp' if has_reply else '0'};")
    cw.nl()

    if has_reply:
        cw.p('err_free:')
        cw.p(call_free(ri, rdir(direction), 'rsp'))
        cw.p("return NULL;")

    cw.block_end()
1856
1857
def print_dump(ri):
    """Write the C implementation of a dump request.

    The generated function sets up a ``ynl_dump_state``: reply policy,
    per-object allocation size, reply parser and expected response
    command.  It then starts a dump message, adds the fixed header and any
    request attributes, and runs ynl_exec_dump().  It returns the head of
    the resulting list (``yds.first``), or frees the list and returns NULL
    on error.
    """
    direction = "request"
    cw = ri.cw

    print_prototype(ri, direction, terminate=False)
    cw.block_start()

    lvars = ['struct ynl_dump_state yds = {};',
             'struct nlmsghdr *nlh;',
             'int err;']
    if ri.fixed_hdr:
        lvars.extend(['size_t hdr_len;', 'void *hdr;'])
    cw.write_func_lvar(lvars)

    setup = [
        'yds.yarg.ys = ys;',
        f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;",
        "yds.yarg.data = NULL;",
        f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});",
        f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;",
    ]
    expected_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    setup.append(f'yds.rsp_cmd = {expected_cmd};')
    for stmt in setup:
        cw.p(stmt)
    cw.nl()

    cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if ri.fixed_hdr:
        for stmt in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            cw.p(stmt)
        cw.nl()

    # Dumps may or may not take request attributes.
    if "request" in ri.op[ri.op_mode]:
        cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        cw.nl()
        for _, member in ri.struct["request"].member_list():
            member.attr_put(ri, "req")
    cw.nl()

    cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    cw.p('if (err < 0)')
    cw.p('goto free_list;')
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
1908
1909
def call_free(ri, direction, var):
    """Return the C statement that frees *var* with the op's free helper."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
1912
1913
def free_arg_name(direction):
    """Return the parameter name used in a free function's prototype.

    Without a direction (nested types) the name is 'obj'.  Otherwise it is
    the direction's suffix from direction_to_suffix with its first
    character dropped (presumably a leading '_' separator).
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
1918
1919
def print_alloc_wrapper(ri, direction):
    """Write a ``static inline`` ``<prefix>_alloc(void)`` helper.

    The helper returns a zeroed (calloc'ed) instance of the op's struct for
    *direction*.
    """
    struct_name = op_prefix(ri, direction)
    cw = ri.cw

    cw.write_func_prot(f'static inline struct {struct_name} *',
                       f"{struct_name}_alloc", ["void"])
    cw.block_start()
    cw.p(f'return calloc(1, sizeof(struct {struct_name}));')
    cw.block_end()
1926
1927
def print_free_prototype(ri, direction, suffix=';'):
    """Write the prototype of ``<prefix>_free()`` for *direction*.

    When ``ri.type_name_conflict`` is set, the struct tag gets a trailing
    '_'.  The function name itself never does.
    """
    fn_prefix = op_prefix(ri, direction)
    struct_name = fn_prefix + '_' if ri.type_name_conflict else fn_prefix
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{fn_prefix}_free", [param], suffix=suffix)
1935
1936
def _print_type(ri, direction, struct):
    """Write the C struct definition for *struct*.

    The struct is named ``<family c_name>_<type_name><direction suffix>``.
    A '_' is appended for direction-less types whose name conflicts
    (``ri.type_name_conflict``), and '_dump' is appended in dump mode.
    Fields, in order: the fixed header ``_hdr`` (if any), an anonymous
    ``_present`` sub-struct with presence members, one ``__u32`` per
    inherited value, and then the members themselves.
    """
    cw = ri.cw

    tag_suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        tag_suffix += '_'
    if ri.op_mode == 'dump':
        tag_suffix += '_dump'

    cw.block_start(line=f"struct {ri.family.c_name}{tag_suffix}")

    if ri.fixed_hdr:
        cw.p(ri.fixed_hdr + ' _hdr;')
        cw.nl()

    # Presence fields ('len' then 'bit' per member) are grouped into an
    # anonymous sub-struct that is opened only when the first one appears.
    present_open = False
    for _, member in struct.member_list():
        for kind in ('len', 'bit'):
            field = member.presence_member(ri.ku_space, kind)
            if not field:
                continue
            if not present_open:
                cw.block_start(line="struct")
                present_open = True
            cw.p(field)
    if present_open:
        cw.block_end(line='_present;')
        cw.nl()

    for inherited in struct.inherited:
        cw.p(f"__u32 {inherited};")

    for _, member in struct.member_list():
        member.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
1972
1973
def print_type(ri, direction):
    """Write the struct definition for the op's *direction* struct."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
1976
1977
def print_type_full(ri, struct):
    """Write the struct definition for *struct* with an empty direction.

    The empty direction selects the direction-less naming in _print_type()
    (presumably used for nested/standalone types).
    """
    no_direction = ""
    _print_type(ri, no_direction, struct)
1980
1981
def print_type_helpers(ri, direction, deref=False):
    """Write the free prototype and, for user-space requests, the setters.

    Args:
        ri: render info for the current op and mode.
        direction: struct direction, 'request' or 'reply' in the callers here.
        deref: forwarded to each member's ``setter()``.
    """
    print_free_prototype(ri, direction)
    ri.cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1990
1991
def print_req_type_helpers(ri):
    """Write the request alloc wrapper plus print_type_helpers() output.

    Skipped entirely when the request struct has no attributes.
    """
    if not ri.struct["request"].attr_list:
        return
    print_alloc_wrapper(ri, "request")
    print_type_helpers(ri, "request")
1997
1998
def print_rsp_type_helpers(ri):
    """Write the reply struct helpers, but only if the mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2003
2004
def print_parse_prototype(ri, direction, terminate=True):
    """Write a ``void <op>_{rsp,req}_parse(tb, obj)`` prototype.

    'reply' selects the ``_rsp`` variant; any other direction selects
    ``_req``.  With *terminate* the prototype ends in ';'.  The object
    parameter is named ``req`` in both variants.
    """
    base = ri.op.render_name + ("_rsp" if direction == "reply" else "_req")
    params = ['const struct nlattr **tb', f"struct {base} *req"]
    ri.cw.write_func_prot('void', f"{base}_parse", params,
                          suffix=';' if terminate else '')
2013
2014
def print_req_type(ri):
    """Write the request struct definition, unless it has no attributes."""
    if ri.struct["request"].attr_list:
        print_type(ri, "request")
2019
2020
def print_req_free(ri):
    """Write the request free function, if the mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2025
2026
def print_rsp_type(ri):
    """Write the reply struct definition when the mode produces one.

    This applies to 'do'/'dump' modes that declare a reply, and to events
    (always).  Other modes write nothing.
    """
    if ri.op_mode in ('do', 'dump'):
        needs_reply_type = 'reply' in ri.op[ri.op_mode]
    else:
        needs_reply_type = ri.op_mode == 'event'

    if needs_reply_type:
        print_type(ri, 'reply')
2035
2036
def print_wrapped_type(ri):
    """Write the wrapper struct around a reply object, plus its free prototype.

    Dump wrappers carry a ``next`` pointer.  Notification/event wrappers
    carry family, cmd, a generic ``next`` and a ``free`` callback.  Both
    embed the reply object as an 8-byte aligned ``obj`` member.
    """
    cw = ri.cw
    wrapper = type_name(ri, 'reply')

    cw.block_start(line=wrapper)
    if ri.op_mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        for field in ('__u16 family;',
                      '__u8 cmd;',
                      'struct ynl_ntf_base_type *next;'):
            cw.p(field)
        cw.p(f"void (*free)({wrapper} *ntf);")
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()

    print_free_prototype(ri, 'reply')
    cw.nl()
2051
2052
2053def _free_type_members_iter(ri, struct):
2054    for _, attr in struct.member_list():
2055        if attr.free_needs_iter():
2056            ri.cw.p('unsigned int i;')
2057            ri.cw.nl()
2058            break
2059
2060
2061def _free_type_members(ri, var, struct, ref=''):
2062    for _, attr in struct.member_list():
2063        attr.free(ri, var, ref)
2064
2065
def _free_type(ri, direction, struct):
    """Write the complete free function for *struct*.

    The members are always released.  The object itself is passed to
    ``free()`` only when *direction* is non-empty.  For an empty direction
    (nested types) the storage is presumably owned by the parent object.
    """
    arg = free_arg_name(direction)
    cw = ri.cw

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
2077
2078
2079def free_rsp_nested_prototype(ri):
2080        print_free_prototype(ri, "")
2081
2082
def free_rsp_nested(ri, struct):
    """Write the free function for nested struct *struct* (members only)."""
    no_direction = ""
    _free_type(ri, no_direction, struct)
2085
2086
def print_rsp_free(ri):
    """Write the reply free function, if the mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2091
2092
def print_dump_type_free(ri):
    """Write the free function for a dump's reply list.

    The generated code walks the list from ``rsp`` until the
    YNL_LIST_END sentinel.  For each node it releases the members of the
    embedded ``obj`` and then frees the node itself.
    """
    cw = ri.cw
    node_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{node_type} *next = rsp;")
    cw.nl()

    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, ri.struct['reply'])
    # Advance before freeing so the next pointer is read from live memory.
    cw.p('rsp = next;')
    cw.p('next = rsp->next;')
    cw.nl()
    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()

    cw.block_end()
    cw.nl()
2111
2112
def print_ntf_type_free(ri):
    """Write the free function for a single notification object.

    It releases the members of the embedded ``obj`` and then the object.
    """
    cw = ri.cw
    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    reply = ri.struct['reply']
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
2121
2122
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Write the declaration line of a request ``nla_policy`` array.

    With *terminate* this is an ``extern`` forward declaration.  It is
    omitted for per-op policies (``ri`` given) of families whose policies
    are static.  Without *terminate* it is the opening line of the
    definition (``... = {``), ``static`` in that same case.

    The array is named after the op (with the mode appended for dual
    policies) when *ri* is given, otherwise after *struct*.  It is sized
    by the attribute set's max enum.
    """
    is_static = bool(ri) and policy_should_be_static(struct.family)

    if terminate:
        if is_static:
            return
        prefix, tail = 'extern ', ';'
    else:
        prefix, tail = ('static ' if is_static else ''), ' = {'

    if ri:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    else:
        name = struct.render_name

    bound = struct.attr_max_val.enum_name
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{bound} + 1]{tail}")
2145
2146
def print_req_policy(cw, struct, ri=None):
    """Write a complete request ``nla_policy`` array definition.

    When an op is known, the definition is wrapped in that op's
    'config-cond' ifdef (if any).  The ifdef block is always reset to
    None afterwards.
    """
    has_op = ri and ri.op
    if has_op:
        cw.ifdef_block(ri.op.get('config-cond', None))

    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")

    cw.ifdef_block(None)
    cw.nl()
2156
2157
def kernel_can_gen_family_struct(family):
    """Return True if the generator may emit the family's genl_family struct.

    Only families whose protocol is exactly 'genetlink' qualify.
    """
    protocol = family.proto
    return protocol == 'genetlink'
2160
2161
def policy_should_be_static(family):
    """Return True if request policies should be ``static`` in the source.

    This holds for split kernel policies, and for families whose
    genl_family struct is generated here.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2164
2165
def print_kernel_policy_ranges(family, cw):
    """Write ``netlink_range_validation`` structs for full-range checks.

    Only request attributes of non-subset attribute sets that declare a
    'full-range' check are considered.  Unsigned types use the plain
    struct with ULL literals; signed types use the ``_signed`` variant with
    LL literals.  A header comment precedes the first struct written.
    """
    header_done = False
    for attr_set in family.attr_sets.values():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if not header_done:
                cw.p('/* Integer value ranges */')
                header_done = True

            is_unsigned = attr.type[0] == 'u'
            sign = '' if is_unsigned else '_signed'
            literal_suffix = 'ULL' if is_unsigned else 'LL'

            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            bounds = [(limit, attr.get_limit_str(limit, suffix=literal_suffix))
                      for limit in ('min', 'max') if limit in attr.checks]
            cw.write_struct_init(bounds)
            cw.block_end(line=';')
            cw.nl()
2193
2194
def print_kernel_op_table_fwd(family, cw, terminate):
    """Write the ops table declaration and, for headers, callback prototypes.

    With ``terminate=False`` this opens the definition of the ops array.
    The caller fills in the entries and closes the block.  With
    ``terminate=True`` it writes the header side: an ``extern``
    declaration of the array (only when the array is exported, i.e. the
    genl_family struct is not generated here), followed by prototypes for
    the pre/post hooks and for every doit/dumpit handler.

    Exported arrays carry an explicit size.  Async (notify/event) ops are
    never placed in the table, so they are excluded from that count.
    """
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.ident_name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        if not exported:
            cnt = ""
        else:
            # Only synchronous ops get table entries (print_kernel_op_table
            # and the prototypes below skip async ones).  Counting async ops
            # would make the declared size exceed the initializer and leave
            # zero-filled entries at the end of the array.
            sync_ops = [op for op in family.ops.values() if not op.is_async]
            if family.kernel_policy == 'split':
                # Split ops have one entry per supported mode.
                cnt = sum(int('do' in op) + int('dump' in op) for op in sync_ops)
            else:
                cnt = len(sync_ops)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    cw.nl()
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.ident_name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.ident_name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
2260
2261
def print_kernel_op_table_hdr(family, cw):
    """Write the header side of the kernel ops table.

    This is print_kernel_op_table_fwd() with ``terminate=True``.
    """
    header_mode = True
    print_kernel_op_table_fwd(family, cw, terminate=header_mode)
2264
2265
def print_kernel_op_table(family, cw):
    """Write the definition of the kernel ops array for *family*.

    The array is opened by print_kernel_op_table_fwd(terminate=False).
    Each entry is wrapped in the op's 'config-cond' ifdef (if any).  Async
    (notify/event) ops are skipped.

    * 'global' / 'per-op' policies: one entry per op, with cmd, optional
      validate flags, doit/dumpit handlers, per-op policy + maxattr (per-op
      only) and flags.
    * 'split' policy: one entry per (op, mode) pair, with the mode-specific
      pre/post callbacks, policy and a ``GENL_CMD_CAP_*`` flag for the mode.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): assumes every op has a 'do' with a 'request';
                # a dump-only op would raise KeyError here. The split branch
                # below guards with "'request' in op[op_mode]" -- confirm.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Names of the genl_split_ops callback members per mode.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the validation flags that apply to this
                    # mode: drop dump-only flags from 'do' entries and
                    # 'strict' from 'dump' entries.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have a separate policy per mode.
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split entry advertises which mode it implements.
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2344
2345
def print_kernel_mcgrp_hdr(family, cw):
    """Write the anonymous enum of multicast group IDs.

    Nothing is written if the family has no groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        # The enumerator's trailing comma is part of the upper-cased string.
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{grp['name']},"))
    cw.block_end(';')
    cw.nl()
2356
2357
def print_kernel_mcgrp_src(family, cw):
    """Write the ``genl_multicast_group`` array, indexed by group enum ID.

    Nothing is written if the family has no groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        enum_id = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{enum_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2369
2370
def print_kernel_family_struct_hdr(family, cw):
    """Write the header declarations for a generated genl_family struct.

    This is an ``extern`` for the struct and, if the family declares
    'sock-priv', prototypes for the socket-private init/destroy hooks.
    Nothing is written unless kernel_can_gen_family_struct() allows it.
    """
    if not kernel_can_gen_family_struct(family):
        return

    prefix = family.c_name
    cw.p(f"extern struct genl_family {prefix}_nl_family;")
    cw.nl()

    if 'sock-priv' in family.kernel_family:
        priv_type = family.kernel_family["sock-priv"]
        for action in ('init', 'destroy'):
            cw.p(f'void {prefix}_nl_sock_priv_{action}({priv_type} *priv);')
        cw.nl()
2381
2382
def print_kernel_family_struct_src(family, cw):
    """Write the definition of the family's ``struct genl_family``.

    Nothing is written unless kernel_can_gen_family_struct() allows it.
    When the family declares 'sock-priv', static trampolines taking
    ``void *`` are emitted first for the socket-private init/destroy hooks.
    They are wired into the struct together with the size of the private
    area.  The ops field used (``ops`` vs ``split_ops``) follows the
    family's kernel policy.
    """
    if not kernel_can_gen_family_struct(family):
        return

    if 'sock-priv' in family.kernel_family:
        # Generate "trampolines" to make CFI happy
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
                      ["void *priv"])
        cw.nl()
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
                      ["void *priv"])
        cw.nl()

    # Use c_name, not ident_name, for the symbol.  It must match the extern
    # declaration emitted for the header and the <c_name>_nl_ops /
    # <c_name>_nl_mcgrps naming used below.  ident_name is not guaranteed to
    # be lower-case (elsewhere it is passed through c_lower()).
    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    if 'sock-priv' in family.kernel_family:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
2418
2419
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a uAPI ``enum`` block and pick its C tag.

    If *obj* has the *enum_name* key, its value (lower-cased) is the tag;
    an empty value yields an anonymous enum.  Otherwise, if *ckey* is given
    and present, the tag is ``<family c_name>_<obj[ckey]>``.  In all other
    cases the enum is anonymous.
    """
    if enum_name in obj:
        explicit = obj[enum_name]
        tag = c_lower(explicit) if explicit else None
    elif ckey and ckey in obj:
        tag = family.c_name + '_' + c_lower(obj[ckey])
    else:
        tag = None

    cw.block_start(line='enum' if tag is None else 'enum ' + tag)
2428
2429
def render_uapi(family, cw):
    """Write the complete uAPI header for *family* to *cw*.

    Output order: include guard; family name/version defines; the spec's
    'definitions' (enums/flags with optional kdoc, consts batched into
    #define blocks); one enum per attribute set that is not a subset of
    another; the command enum, plus a separate async enum for notify/event
    messages when the operations declare 'async-prefix'; multicast group
    name defines; closing #endif.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    # The header name may contain '/', which is not valid in a macro name.
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive 'const' definitions are collected in `defines` and written
    # as one block; any other definition type flushes the pending batch.
    defines = []
    for const in family['definitions']:
        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                # Per-entry docs -> kdoc ('/**' with @entry lines);
                # otherwise a plain comment holding the enum's doc.
                if enum.has_entry_doc():
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                # Explicit value only when the entry breaks the implicit sequence.
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                # Flags get a <PFX>MASK enumerator; enums get __<PFX>MAX / <PFX>MAX.
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    max_name = c_upper(name_pfx + 'max')
                    cw.p('__' + max_name + ',')
                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # NOTE(review): 'c-define-name' is looked up on the *family*, not
            # on `const` (contrast the multicast loop below, which uses
            # grp.get). If a family ever set it, every const would get the
            # same macro name. Possibly meant const.get(...) -- confirm.
            defines.append([c_upper(family.get('c-define-name',
                                               f"{family.ident_name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    # When set, the MAX value is a #define after the enum instead of an
    # enumerator inside it.
    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        # `val` tracks the value C would assign implicitly, so an explicit
        # "= N" is only written when the attribute's value differs from it.
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    # With 'async-prefix', notify/event messages go into their own enum.
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    val = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            # Only values given explicitly in the spec are written.
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    # Each group name becomes a string define, named by the group's
    # 'c-define-name' or <family>-mcgrp-<name>.
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2578
2579
def _render_user_ntf_entry(ri, op):
    """Write the ``[<op enum>] = { ... },`` entry for @op in a ynl_ntf_info table.

    The entry records the allocation size of the event type, the reply parse
    callback, the reply nest policy and the notification free helper, all
    derived from the RenderInfo @ri.
    """
    out = ri.cw
    out.block_start(line=f"[{op.enum_name}] = ")

    event_type = type_name(ri, 'event')
    out.p(f".alloc_sz\t= sizeof({event_type}),")

    parse_prefix = op_prefix(ri, 'reply', deref=True)
    out.p(f".cb\t\t= {parse_prefix}_parse,")

    reply_struct = ri.struct['reply']
    out.p(f".policy\t\t= &{reply_struct.render_name}_nest,")

    free_prefix = op_prefix(ri, 'notify')
    out.p(f".free\t\t= (void *){free_prefix}_free,")

    out.block_end(line=',')
2587
2588
def render_user_family(family, cw, prototype):
    """Emit the ``ynl_<c_name>_family`` descriptor for the user-space library.

    With @prototype set, only the ``extern`` declaration is written (used by
    the generated header). Otherwise the definition is written, preceded by a
    static ``ynl_ntf_info`` table when the family has notifications.

    Raises Exception if a notification has neither a 'notify' nor an 'event'
    key.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    # The table name is a C identifier referenced in two places below; build
    # it from c_name (as the family symbol is) rather than the raw spec name,
    # which may contain characters such as '-' that are invalid in C.
    ntf_info = f'{family.c_name}_ntf_info'

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {ntf_info}[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                # Rendered via the op named by the notification's 'notify' key.
                op = family.ops[ntf_op['notify']]
                ri = RenderInfo(cw, family, "user", op, "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        # NOTE(review): if family.ops can also hold 'event' messages that are
        # already in family.ntfs, their entries would be emitted twice —
        # confirm how Family partitions ops vs. ntfs.
        for op_name, op in family.ops.items():
            if 'event' not in op:
                continue
            ri = RenderInfo(cw, family, "user", op, "event")
            _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {ntf_info},")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({ntf_info}),")
    cw.block_end(line=';')
2624
2625
def family_contains_bitfield32(family):
    """Tell whether any attribute of @family has type ``bitfield32``.

    Attribute sets marked as ``subset_of`` another set are not inspected.
    Stops at the first match.
    """
    return any(
        attr.type == "bitfield32"
        for _, attr_set in family.attr_sets.items()
        if not attr_set.subset_of
        for _, attr in attr_set.items()
    )
2634
2635
def find_kernel_root(full_path):
    """Locate the kernel source tree that contains @full_path.

    Walks up from @full_path until a directory holding a ``MAINTAINERS``
    file is found.

    Returns a ``(root, sub_path)`` tuple where ``sub_path`` is @full_path
    relative to ``root``. For a relative @full_path whose top-level component
    sits in the current directory, ``root`` is ``''``.

    Raises Exception if no ancestor directory contains ``MAINTAINERS``.
    """
    start = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # Joining onto the initial '' leaves a trailing separator; drop it.
            return parent, sub_path[:-1]
        # dirname() stops changing at '/' (or '' for relative paths); without
        # this check the walk would loop forever when no root exists.
        if parent == full_path:
            raise Exception(f"Cannot find kernel root (directory with MAINTAINERS) above '{start}'")
        full_path = parent
2644
2645
def main():
    """Command-line entry point: parse a YNL spec and emit generated C code.

    ``--mode`` selects ``uapi``, ``kernel`` or ``user`` output and
    ``--header``/``--source`` selects which half is written. Output goes to
    ``-o`` when given, otherwise to whatever CodeWriter does with no output
    file.

    Exits with status 1 on a YAML parse error, a spec license other than
    ``((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)``, or an unsupported
    message enum-model. Reports an argparse error when neither ``--header``
    nor ``--source`` is given.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    # --header and --source share one dest; it stays None if neither is given.
    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
            print('Spec license:', parsed.license)
            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
            sys.exit(1)
    except yaml.YAMLError as exc:
        print(exc)
        sys.exit(1)

    supported_models = ['unified']
    if args.mode in ['user', 'kernel']:
        supported_models += ['directional']
    if parsed.msg_id_model not in supported_models:
        print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation')
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    # Companion header name derived from the output file (foo.c -> foo.h).
    # -o is optional, so without it there is no name to derive and the
    # self-include is skipped instead of crashing on None.
    hdr_file = None
    if args.out_file:
        hdr_file = os.path.splitext(os.path.basename(args.out_file))[0] + ".h"

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if hdr_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            if hdr_file:
                cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = [parsed.uapi_header]
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for op_name, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            print_kernel_policy_ranges(parsed, cw)

            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            for op_name, op in parsed.ops.items():
                if parsed.kernel_policy in {'per-op', 'split'}:
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    if not ri.type_consistent:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            has_recursive_nests = False
            cw.p('/* Policies */')
            for struct in parsed.pure_nested_structs.values():
                if struct.recursive:
                    put_typol_fwd(cw, struct)
                    has_recursive_nests = True
            if has_recursive_nests:
                cw.nl()
            for name in parsed.pure_nested_structs:
                struct = Struct(parsed, name)
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            # Recursive nests reference each other, so emit prototypes first.
            if has_recursive_nests:
                for attr_set, struct in parsed.pure_nested_structs.items():
                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                    free_rsp_nested_prototype(ri)
                    if struct.request:
                        put_req_nested_prototype(ri, struct)
                    if struct.reply:
                        parse_rsp_nested_prototype(ri, struct)
                cw.nl()
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent:
                        parse_rsp_msg(ri, deref=True)
                    print_req_free(ri)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    # The event's parser is generated through a "do" RenderInfo.
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
2952
2953
if __name__ == "__main__":
    # Run the generator only when executed as a script, not when imported.
    main()
2956