xref: /linux/tools/net/ynl/ynl-gen-c.py (revision fcc79e1714e8c2b8e216dc3149812edd37884eef)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17
18
def c_upper(name):
    """Return name in upper case with every dash replaced by an underscore."""
    return name.replace('-', '_').upper()
21
22
def c_lower(name):
    """Return name in lower case with every dash replaced by an underscore."""
    return name.replace('-', '_').lower()
25
26
def limit_to_number(name):
    """
    Convert a symbolic limit (e.g. u32-max, s64-min) to the integer it names

    The first character selects signedness ('u' or 's'), the digits which
    follow give the width in bits and the last four characters are the
    "-min" / "-max" suffix.
    """
    signed = name[0] == 's'
    wants_min = name.endswith('-min')

    # Smallest unsigned value does not depend on the width
    if name[0] == 'u' and wants_min:
        return 0

    bits = int(name[1:-4])
    if signed:
        # One bit is taken by the sign
        bits -= 1
    limit = (1 << bits) - 1

    if signed and wants_min:
        return -limit - 1
    return limit
40
41
class BaseNlLib:
    """Helper returning code snippets for the generated netlink code."""

    def get_family_id(self):
        """
        Return the expression used to read the family ID

        NOTE(review): presumably 'ys' is the socket variable in the
        generated C code - confirm against the callers.
        """
        family_id_expr = 'ys->family_id'
        return family_id_expr
45
46
class Type(SpecAttr):
    """
    Base class of the per-attribute-type code generators

    Wraps a single attribute of the spec and provides the hooks used to
    render it: struct members, presence tracking, kernel policy entries,
    user space type policies, put / get code and request setters.
    Subclasses override the hooks which differ for their type; hooks that
    have no sensible default raise an Exception here.
    """

    def __init__(self, family, attr_set, attr, value):
        """
        Record the properties shared by all attribute types

        attr is the attribute's spec dictionary, it must contain 'type'
        and may contain 'checks', 'len' and 'nested-attributes'.
        """
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # Nested sets named like one of the family's constants get a
            # trailing underscore in their struct name
            # NOTE(review): presumably to avoid a C name clash - confirm
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def get_limit(self, limit, default=None):
        """
        Return the numeric value of check @limit, or @default if not set

        Integers are returned as is, symbolic limits (e.g. u32-max) are
        converted with limit_to_number(). Raises Exception when the check
        refers to a family constant.
        """
        value = self.checks.get(limit, default)
        if value is None or isinstance(value, int):
            return value
        if value in self.family.consts:
            raise Exception("Resolving family constants not implemented, yet")
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """
        Return check @limit as a string for use in generated code

        Unset checks render as an empty string, integers get @suffix
        appended, family constants are prefixed with the family name and
        everything is converted with c_upper().
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute enum_name from the attribute's or the set's name prefix"""
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

    def is_multi_val(self):
        """Tell if the attribute holds multiple values, falsy by default"""
        return None

    def is_scalar(self):
        """Tell if the attribute's type is one of the fixed-width integers"""
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        """Tell if the attribute is recursive, never the case by default"""
        return False

    def is_recursive_for_op(self, ri):
        """Tell if the attribute is recursive and @ri has no operation"""
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """
        Return how presence of the attribute is tracked

        Values used by this class and its subclasses are 'bit', 'len',
        'count' and '' (not tracked).
        """
        return 'bit'

    def presence_member(self, space, type_filter):
        """
        Return the presence member declaration, if of kind @type_filter

        Returns None when the presence type does not match @type_filter
        or when it needs no member. Types get a '__' prefix when @space
        is 'user'.
        """
        presence = self.presence_type()
        if presence != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if presence == 'bit':
            return f"{pfx}u32 {self.c_name}:1;"
        if presence == 'len':
            return f"{pfx}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        """Return the C type of a non-trivial struct member, None by default"""
        return None

    def free_needs_iter(self):
        """Tell if the code emitted by free() uses an iterator variable"""
        return False

    def free(self, ri, var, ref):
        """Emit a free() call if the member is an array or tracked by length"""
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """
        Return the list of C argument declarations for the setter

        The default only supports types with _complex_member_type(),
        for anything else an Exception is raised.
        """
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the declaration(s) of the struct member(s) for the attribute"""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            # Arrays and recursive members are stored as pointers
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Return the initializer of a policy entry of type @policy"""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit the policy table entry for the attribute"""
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the type policy fields, must be provided by subclasses"""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit the type policy table entry for the attribute"""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        """Emit @line, guarded by a presence check where presence is tracked"""
        presence = self.presence_type()
        if presence == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif presence == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a ynl_attr_put_<put_type>() call for the member"""
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        """Emit the code adding the attribute to a message, per subclass"""
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """
        Return the (lines, init_lines, local_vars) needed to parse the attr

        Must be provided by subclasses. lines and init_lines may be either
        a string or a list of strings, init_lines and local_vars may be None.
        """
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Emit the branch of the parsing code which handles the attribute

        @first selects between starting the chain with "if" and continuing
        it with "else if". Always returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        # _attr_get() is allowed to return a bare string for a single line
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        """Return the setter body lines storing the value, per subclass"""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """
        Emit the static inline setter function for the attribute

        @ref is the list of names of the nests leading to the attribute,
        presence is marked on every level of the path.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"
        last = len(ref) - 1

        code = []
        presence = ''
        for i, name in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{name}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == last and self.presence_type() != 'bit':
                continue
            code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        # Setters which free the old value without allocating a new one
        # get a '__' prefix
        # NOTE(review): presumably because they take ownership of the
        # caller's pointer - confirm
        frees = any('free(' in x for x in code)
        allocs = any('alloc(' in x for x in code)
        if frees and not allocs:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
261
262
class TypeUnused(Type):
    """Attribute of type 'unused', no members or code are rendered for it"""

    def presence_type(self):
        """Presence is not tracked"""
        return ''

    def arg_member(self, ri):
        """There are no members / arguments to declare"""
        return []

    def _attr_typol(self):
        """Type policy entry rejecting the attribute"""
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        """Parsing lines which fail with an error"""
        return ['return YNL_PARSE_CB_ERROR;'], None, None

    def attr_policy(self, cw):
        """Emit nothing"""
        return None

    def attr_put(self, ri, var):
        """Emit nothing"""
        return None

    def attr_get(self, ri, var, first):
        """Emit nothing"""
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit nothing"""
        return None
287
288
class TypePad(Type):
    """Attribute of type 'pad', no members or code are rendered for it"""

    def presence_type(self):
        """Presence is not tracked"""
        return ''

    def arg_member(self, ri):
        """There are no members / arguments to declare"""
        return []

    def _attr_typol(self):
        """Type policy entry ignoring the attribute"""
        return '.type = YNL_PT_IGNORE, '

    def attr_policy(self, cw):
        """Emit nothing"""
        return None

    def attr_put(self, ri, var):
        """Emit nothing"""
        return None

    def attr_get(self, ri, var, first):
        """Emit nothing"""
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit nothing"""
        return None
310
311
class TypeScalar(Type):
    """Integer attribute, with optional enum, flags and range checks"""

    def __init__(self, family, attr_set, attr, value):
        """
        Derive the value checks of the scalar

        Limits are filled in from the enum (if any) and validated,
        'range' and 'full-range' markers are added to the checks.
        Raises Exception on inconsistent limits.
        """
        super().__init__(family, attr_set, attr, value)

        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"
        else:
            self.byte_order_comment = ''

        if 'enum' in self.attr:
            # Limits not given explicitly come from the enum's values
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if 'min' not in self.checks and (low != 0 or self.type[0] == 's'):
                self.checks['min'] = low
            if 'max' not in self.checks:
                self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        limit_min = self.get_limit('min', 0)
        limit_max = self.get_limit('max', 0)
        low = min(limit_min, limit_max)
        high = max(limit_min, limit_max)
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # NOTE(review): presumably limits which do not fit in 16 bits can't
        # be stored in the policy entry itself - confirm
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Work out if the scalar carries flags and what its C type is"""
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            self.is_bitfield = enum['type'] == 'flags'
        else:
            self.is_bitfield = False

        if 'enum' in self.attr and not self.is_bitfield:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _attr_policy(self, policy):
        """Return the policy initializer matching the checks of the scalar"""
        if self.is_bitfield:
            enum = self.family.consts[self.attr['enum']]
            mask = enum.get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'flags-mask' in self.checks:
            # Mask covers one bit per entry of the referenced flags
            flags = self.family.consts[self.checks['flags-mask']]
            mask = (1 << len(flags['entries'])) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        if 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        if 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        if 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        """Type policy fields, named after the type without its first letter"""
        width = c_upper(self.type[1:])
        return f'.type = YNL_PT_U{width}, '

    def arg_member(self, ri):
        """Declaration of the member / argument holding the value"""
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        """Emit the put call matching the scalar's type"""
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        """Parsing line reading the value with the getter of the type"""
        getter = f"ynl_attr_get_{self.type}(attr)"
        return f"{var}->{self.c_name} = {getter};", None, None

    def _setter_lines(self, ri, member, presence):
        """Setter body, a plain assignment of the argument"""
        return [f"{member} = {self.c_name};"]
398
399
class TypeFlag(Type):
    """Attribute of type 'flag', it carries no payload, only presence"""

    def arg_member(self, ri):
        """No value to store or pass, presence is all there is"""
        return []

    def _attr_typol(self):
        """Type policy fields of a flag"""
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        """Emit a put of an attribute with no data"""
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        """Nothing to parse beyond what attr_get() emits"""
        return [], None, None

    def _setter_lines(self, ri, member, presence):
        """Nothing to store beyond the presence marked by setter()"""
        return []
415
416
class TypeString(Type):
    """String attribute, stored as an allocated, NUL-terminated copy"""

    def presence_type(self):
        """Presence is tracked with a length member"""
        return 'len'

    def arg_member(self, ri):
        """Setter takes a constant string"""
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        """Emit the pointer member holding the string"""
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        """Type policy fields of a string"""
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        """Policy initializer, honoring the exact-len and max-len checks"""
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'

        fields = '{ .type = ' + policy
        if 'max-len' in self.checks:
            fields += ', .len = ' + self.get_limit_str('max-len')
        return fields + ', }'

    def attr_policy(self, cw):
        """Emit the policy entry, type depends on the unterminated-ok check"""
        if self.checks.get('unterminated-ok', False):
            spec = self._attr_policy('NLA_STRING')
        else:
            spec = self._attr_policy('NLA_NUL_STRING')
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def attr_put(self, ri, var):
        """Emit the string put call"""
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        """Parsing lines copying the string into a fresh allocation"""
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"

        local_vars = ['unsigned int len;']
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        get_lines = [f"{len_mem} = len;",
                     f"{dst} = malloc(len + 1);",
                     f"memcpy({dst}, ynl_attr_get_str(attr), len);",
                     f"{dst}[len] = 0;"]
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        """Setter body, frees the old value and stores a copy of the new one"""
        len_mem = f"{presence}_len"
        return [f"free({member});",
                f"{len_mem} = strlen({self.c_name});",
                f"{member} = malloc({len_mem} + 1);",
                f"memcpy({member}, {self.c_name}, {len_mem});",
                f"{member}[{len_mem}] = 0;"]
467
468
class TypeBinary(Type):
    """Binary attribute, stored as an allocated buffer and its length"""

    def presence_type(self):
        """Presence is tracked with a length member"""
        return 'len'

    def arg_member(self, ri):
        """Setter takes a constant buffer and its length"""
        return [f"const void *{self.c_name}", 'size_t len']

    def struct_member(self, ri):
        """Emit the pointer member holding the data"""
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        """Type policy fields of binary data"""
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        """
        Policy initializer, honoring at most one length check

        Raises Exception if there are multiple checks or the check is
        not one of exact-len, min-len and max-len.
        """
        checks = list(self.checks)
        if len(checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        if not checks:
            return '{ .type = NLA_BINARY, }'

        check_name = checks[0]
        if check_name not in {'exact-len', 'min-len', 'max-len'}:
            raise Exception('Unsupported check for binary type: ' + check_name)

        limit = self.get_limit_str(check_name)
        if check_name == 'exact-len':
            return 'NLA_POLICY_EXACT_LEN(' + limit + ')'
        if check_name == 'min-len':
            return '{ .len = ' + limit + ', }'
        return 'NLA_POLICY_MAX_LEN(' + limit + ')'

    def attr_put(self, ri, var):
        """Emit a put of the buffer with the length stored in presence"""
        data = f"{var}->{self.c_name}"
        data_len = f"{var}->_present.{self.c_name}_len"
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, {data}, {data_len})")

    def _attr_get(self, ri, var):
        """Parsing lines copying the payload into a fresh allocation"""
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"

        local_vars = ['unsigned int len;']
        init_lines = ['len = ynl_attr_data_len(attr);']
        get_lines = [f"{len_mem} = len;",
                     f"{dst} = malloc(len);",
                     f"memcpy({dst}, ynl_attr_data(attr), len);"]
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        """Setter body, frees the old value and stores a copy of the new one"""
        len_mem = f"{presence}_len"
        return [f"free({member});",
                f"{len_mem} = len;",
                f"{member} = malloc({len_mem});",
                f"memcpy({member}, {self.c_name}, {len_mem});"]
520
521
class TypeBitfield32(Type):
    """Attribute of type 'bitfield32', kept as a struct nla_bitfield32"""

    _C_STRUCT = "struct nla_bitfield32"

    def _complex_member_type(self, ri):
        """C type of the struct member"""
        return self._C_STRUCT

    def _attr_typol(self):
        """Type policy fields of a bitfield32"""
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        """
        Policy initializer with the mask of the attribute's enum

        Raises Exception if the attribute has no enum.
        """
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        """Emit a put of the whole struct"""
        self._attr_put_line(ri, var,
                            f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, sizeof(struct nla_bitfield32))")

    def _attr_get(self, ri, var):
        """Parsing line copying the payload into the member"""
        copy = f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));"
        return copy, None, None

    def _setter_lines(self, ri, member, presence):
        """Setter body, copies the struct the argument points to"""
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
545
546
class TypeNest(Type):
    """Attribute of type 'nest', rendered as a member of the nested struct"""

    def is_recursive(self):
        """Tell if the nested struct is marked as recursive"""
        nested = self.family.pure_nested_structs[self.nested_attrs]
        return nested.recursive

    def _complex_member_type(self, ri):
        """C type of the struct member, the struct of the nested set"""
        return self.nested_struct_type

    def free(self, ri, var, ref):
        """Emit the call freeing the nested struct"""
        target = f'{var}->{ref}{self.c_name}'
        if self.is_recursive_for_op(ri):
            # Member is a pointer here, it may not be set
            ri.cw.p(f'if ({target})')
            ri.cw.p(f'{self.nested_render_name}_free({target});')
        else:
            ri.cw.p(f'{self.nested_render_name}_free(&{target});')

    def _attr_typol(self):
        """Type policy fields pointing at the policy of the nested set"""
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        """Policy initializer referring to the policy of the nested set"""
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        """Emit the call of the nested struct's put function"""
        if self.is_recursive_for_op(ri):
            arg = f"{var}->{self.c_name}"
        else:
            arg = f"&{var}->{self.c_name}"
        self._attr_put_line(ri, var,
                            f"{self.nested_render_name}_put(nlh, {self.enum_name}, {arg})")

    def _attr_get(self, ri, var):
        """Parsing lines calling the parser of the nested struct"""
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [f"if ({self.nested_render_name}_parse(&parg, attr))",
                     "return YNL_PARSE_CB_ERROR;"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit the setters of all non-recursive members of the nest"""
        path = (ref if ref else []) + [self.c_name]

        members = ri.family.pure_nested_structs[self.nested_attrs].member_list()
        for _, attr in members:
            if not attr.is_recursive():
                attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
586
587
class TypeMultiAttr(Type):
    """
    Attribute which may appear multiple times, stored as an array

    Wraps the type object created for a single instance (@base_type),
    policies are delegated to it. Scalar and nest elements are supported.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        """Remember @base_type, the type object of a single instance"""
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def _elem_is_nest(self):
        """Tell if the elements are nests, also assumed if type is not given"""
        elem_type = self.attr.get('type')
        return elem_type is None or elem_type == 'nest'

    def is_multi_val(self):
        """Multi-attr always holds multiple values"""
        return True

    def presence_type(self):
        """Presence is tracked with an element count"""
        return 'count'

    def _complex_member_type(self, ri):
        """
        C type of a single element of the array

        Raises Exception for unsupported element types.
        """
        if self._elem_is_nest():
            return self.nested_struct_type
        if self.attr['type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        """Freeing nests requires iterating over the elements"""
        return self._elem_is_nest()

    def free(self, ri, var, ref):
        """
        Emit the code freeing the array, and for nests also its elements

        Raises Exception for unsupported element types.
        """
        # Check for nest / missing type before indexing attr['type'],
        # same order as in _complex_member_type()
        if self._elem_is_nest():
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        elif self.attr['type'] in scalars:
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        else:
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        """Policy initializer of a single instance"""
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        """Type policy fields of a single instance"""
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        """Parsing line which only counts the instances of the attribute"""
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        """
        Emit a loop putting each element of the array

        Raises Exception for unsupported element types.
        """
        # Check for nest / missing type before indexing attr['type'],
        # same order as in _complex_member_type()
        if self._elem_is_nest():
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        elif self.attr['type'] in scalars:
            put_type = self.type
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        """Setter body, frees the old array and stores the pointer and count"""
        # For multi-attr we have a count, not presence, hack up the presence
        presence = presence[:-(len('_present.') + len(self.c_name))] + "n_" + self.c_name
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
649
650
class TypeArrayNest(Type):
    """Attribute of type 'indexed-array' with nests, stored as an array"""

    def is_multi_val(self):
        """An array always holds multiple values"""
        return True

    def presence_type(self):
        """Presence is tracked with an element count"""
        return 'count'

    def _complex_member_type(self, ri):
        """
        C type of a single element of the array

        Raises Exception for unsupported sub-types.
        """
        if 'sub-type' not in self.attr:
            return self.nested_struct_type

        sub_type = self.attr['sub-type']
        if sub_type == 'nest':
            return self.nested_struct_type
        if sub_type in scalars:
            if ri.ku_space == 'user':
                return '__' + sub_type
            return sub_type
        raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")

    def _attr_typol(self):
        """Type policy fields pointing at the policy of the nested set"""
        nest = f'&{self.nested_render_name}_nest'
        return f'.type = YNL_PT_NEST, .nest = {nest}, '

    def _attr_get(self, ri, var):
        """Parsing lines saving the attribute and counting its entries"""
        counter = f'{var}->n_{self.c_name}'
        get_lines = [
            f'attr_{self.c_name} = attr;',
            'ynl_attr_for_each_nested(attr2, attr)',
            f'\t{counter}++;',
        ]
        return get_lines, None, ['const struct nlattr *attr2;']
676
677
class TypeNestTypeValue(Type):
    """Attribute of type 'nest-type-value', a nest wrapped in typed levels"""

    def _complex_member_type(self, ri):
        """C type of the struct member, the struct of the nested set"""
        return self.nested_struct_type

    def _attr_typol(self):
        """Type policy fields pointing at the policy of the nested set"""
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """
        Parsing lines which unwrap the levels and parse the nest

        For each name listed under 'type-value' one level of nesting is
        entered and the type of the attribute found there is saved, the
        saved types are passed to the parser as extra arguments.
        """
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        local_vars = []
        get_lines = []
        innermost = 'attr'
        tv_args = ''

        if 'type-value' in self.attr:
            tv_names = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(tv_names)};')
            local_vars.append(f'__u32 {", ".join(tv_names)};')
            for level in tv_names:
                get_lines.append(f'attr_{level} = ynl_attr_data({innermost});')
                get_lines.append(f'{level} = ynl_attr_type(attr_{level});')
                innermost = 'attr_' + level

            tv_args = f", {', '.join(tv_names)}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {innermost}{tv_args});")
        return get_lines, init_lines, local_vars
706
707
class Struct:
    """
    Set of attributes which gets rendered as a single C struct

    Holds either all attributes of the attribute set @space_name, or
    just the ones listed in @type_list. Structs created without
    a @type_list are considered nested.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        """Collect the members and work out the names of the struct"""
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            self.render_name = c_lower(family.ident_name)
        else:
            self.render_name = c_lower(family.ident_name + '-' + space_name)
        self.struct_name = 'struct ' + self.render_name
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False

        # Without an explicit list the struct gets the whole attr set
        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(t, self.attr_set[t]) for t in names]
        self.attrs = dict(self.attr_list)

        # Find the member with the highest value, last one wins on a tie
        highest = 0
        self.attr_max_val = None
        for _, attr in self.attr_list:
            if attr.value >= highest:
                highest = attr.value
                self.attr_max_val = attr

    def __iter__(self):
        """Iterate over the names of the members"""
        yield from self.attrs

    def __getitem__(self, key):
        """Return the member called @key"""
        return self.attrs[key]

    def member_list(self):
        """Return the list of (name, attribute) tuples of the members"""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """
        Confirm the list of inherited members and store their C names

        Raises Exception if @new_inherited differs from the list given
        at creation time.
        """
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]
763
764
class EnumEntry(SpecEnumEntry):
    """Entry of an enum, extended with the info needed to render it in C"""

    def __init__(self, enum_set, yaml, prev, value_start):
        """
        Work out whether the entry's value differs from the implied one

        value_change is set if the value does not directly follow the
        value of @prev (or is not 0 for the first entry), and always for
        enums of type 'flags'.
        """
        super().__init__(enum_set, yaml, prev, value_start)

        if prev:
            implied = prev.value + 1
        else:
            implied = 0
        self.value_change = self.value != implied or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Compute the C name from the value prefix of the enum"""
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
783
784
class EnumSet(SpecEnumSet):
    """Enum or flags definition, extended with the C names used to render it"""

    def __init__(self, family, yaml):
        """
        Work out the C names before handing over to the spec class

        enum_name is None if the spec sets 'enum-name' to an empty value,
        in which case values of the enum are typed as plain int.
        """
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None

        self.user_type = self.enum_name if self.enum_name else 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        """Create the object for a single entry of the enum"""
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """
        Return the (lowest, highest) values of the entries

        Raises Exception if the values do not form a contiguous range.
        """
        values = [x.value for x in self.entries.values()]
        low, high = min(values), max(values)

        if high - low + 1 != len(self.entries):
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
818
819
class AttrSet(SpecAttrSet):
    """Attribute set, extended with C names and typed attribute objects"""

    def __init__(self, family, yaml):
        """
        Work out the name prefix and the names of the max / count defines

        Subsets take all three from the set they are a subset of.
        """
        super().__init__(family, yaml)

        if self.subset_of is not None:
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """
        Compute the C name of the set

        C keywords get a trailing underscore, the name is empty if it
        matches the C name of the family.
        """
        c_name = c_lower(self.name)
        if c_name in _C_KW:
            c_name += '_'
        if c_name == self.family.c_name:
            c_name = ''
        self.c_name = c_name

    def new_attr(self, elem, value):
        """
        Create the typed object for the attribute described by @elem

        Attributes marked as multi-attr get wrapped in TypeMultiAttr.
        Raises Exception for types and sub-types which are not supported.
        """
        simple_types = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'binary': TypeBinary,
            'bitfield32': TypeBitfield32,
            'nest': TypeNest,
            'nest-type-value': TypeNestTypeValue,
        }

        kind = elem['type']
        if kind in scalars:
            cls = TypeScalar
        elif kind in simple_types:
            cls = simple_types[kind]
        elif kind == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] != 'nest':
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            cls = TypeArrayNest
        else:
            raise Exception(f"No typed class for type {elem['type']}")

        t = cls(self.family, self, elem, value)
        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
881
882
class Operation(SpecOperation):
    """
    Operation extended with the names and flags needed for C rendering.

    The enum_name attribute only exists after resolve().
    """

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(f"{family.ident_name}_{self.name}")

        def takes_request(op_mode):
            return op_mode in yaml and 'request' in yaml[op_mode]

        # Both 'do' and 'dump' accept a request
        self.dual_policy = takes_request('do') and takes_request('dump')

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Resolve the spec level and compute the name of the op's enum entry."""
        self.resolve_up(super())

        if self.is_async:
            prefix = self.family.async_op_prefix
        else:
            prefix = self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that some notification refers to this operation."""
        self.has_ntf = True
908
909
910class Family(SpecFamily):
911    def __init__(self, file_name, exclude_ops):
912        # Added by resolve:
913        self.c_name = None
914        delattr(self, "c_name")
915        self.op_prefix = None
916        delattr(self, "op_prefix")
917        self.async_op_prefix = None
918        delattr(self, "async_op_prefix")
919        self.mcgrps = None
920        delattr(self, "mcgrps")
921        self.consts = None
922        delattr(self, "consts")
923        self.hooks = None
924        delattr(self, "hooks")
925
926        super().__init__(file_name, exclude_ops=exclude_ops)
927
928        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
929        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))
930
931        if 'definitions' not in self.yaml:
932            self.yaml['definitions'] = []
933
934        if 'uapi-header' in self.yaml:
935            self.uapi_header = self.yaml['uapi-header']
936        else:
937            self.uapi_header = f"linux/{self.ident_name}.h"
938        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
939            self.uapi_header_name = self.uapi_header[6:-2]
940        else:
941            self.uapi_header_name = self.ident_name
942
943    def resolve(self):
944        self.resolve_up(super())
945
946        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
947            raise Exception("Codegen only supported for genetlink")
948
949        self.c_name = c_lower(self.ident_name)
950        if 'name-prefix' in self.yaml['operations']:
951            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
952        else:
953            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
954        if 'async-prefix' in self.yaml['operations']:
955            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
956        else:
957            self.async_op_prefix = self.op_prefix
958
959        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})
960
961        self.hooks = dict()
962        for when in ['pre', 'post']:
963            self.hooks[when] = dict()
964            for op_mode in ['do', 'dump']:
965                self.hooks[when][op_mode] = dict()
966                self.hooks[when][op_mode]['set'] = set()
967                self.hooks[when][op_mode]['list'] = []
968
969        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
970        self.root_sets = dict()
971        # dict space-name -> set('request', 'reply')
972        self.pure_nested_structs = dict()
973
974        self._mark_notify()
975        self._mock_up_events()
976
977        self._load_root_sets()
978        self._load_nested_sets()
979        self._load_attr_use()
980        self._load_hooks()
981
982        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
983        if self.kernel_policy == 'global':
984            self._load_global_policy()
985
    def new_enum(self, elem):
        """
        Create the code generator's EnumSet for the definition elem.

        NOTE(review): presumably a factory hook invoked by the SpecFamily
        base class while loading definitions - confirm against lib/nlspec.py.
        """
        return EnumSet(self, elem)
988
    def new_attr_set(self, elem):
        """
        Create the code generator's AttrSet for the attribute set elem.

        NOTE(review): presumably a factory hook invoked by the SpecFamily
        base class while loading attribute sets - confirm against
        lib/nlspec.py.
        """
        return AttrSet(self, elem)
991
    def new_operation(self, elem, req_value, rsp_value):
        """
        Create the code generator's Operation for the operation elem.

        req_value and rsp_value are forwarded to Operation unchanged.
        NOTE(review): presumably a factory hook invoked by the SpecFamily
        base class while loading operations - confirm against lib/nlspec.py.
        """
        return Operation(self, elem, req_value, rsp_value)
994
995    def _mark_notify(self):
996        for op in self.msgs.values():
997            if 'notify' in op:
998                self.ops[op['notify']].mark_has_ntf()
999
1000    # Fake a 'do' equivalent of all events, so that we can render their response parsing
1001    def _mock_up_events(self):
1002        for op in self.yaml['operations']['list']:
1003            if 'event' in op:
1004                op['do'] = {
1005                    'reply': {
1006                        'attributes': op['event']['attributes']
1007                    }
1008                }
1009
1010    def _load_root_sets(self):
1011        for op_name, op in self.msgs.items():
1012            if 'attribute-set' not in op:
1013                continue
1014
1015            req_attrs = set()
1016            rsp_attrs = set()
1017            for op_mode in ['do', 'dump']:
1018                if op_mode in op and 'request' in op[op_mode]:
1019                    req_attrs.update(set(op[op_mode]['request']['attributes']))
1020                if op_mode in op and 'reply' in op[op_mode]:
1021                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
1022            if 'event' in op:
1023                rsp_attrs.update(set(op['event']['attributes']))
1024
1025            if op['attribute-set'] not in self.root_sets:
1026                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
1027            else:
1028                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
1029                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)
1030
1031    def _sort_pure_types(self):
1032        # Try to reorder according to dependencies
1033        pns_key_list = list(self.pure_nested_structs.keys())
1034        pns_key_seen = set()
1035        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
1036        for _ in range(rounds):
1037            if len(pns_key_list) == 0:
1038                break
1039            name = pns_key_list.pop(0)
1040            finished = True
1041            for _, spec in self.attr_sets[name].items():
1042                if 'nested-attributes' in spec:
1043                    nested = spec['nested-attributes']
1044                    # If the unknown nest we hit is recursive it's fine, it'll be a pointer
1045                    if self.pure_nested_structs[nested].recursive:
1046                        continue
1047                    if nested not in pns_key_seen:
1048                        # Dicts are sorted, this will make struct last
1049                        struct = self.pure_nested_structs.pop(name)
1050                        self.pure_nested_structs[name] = struct
1051                        finished = False
1052                        break
1053            if finished:
1054                pns_key_seen.add(name)
1055            else:
1056                pns_key_list.append(name)
1057
    def _load_nested_sets(self):
        """
        Find the attribute sets which are used as nests and describe them.

        Walks the nests reachable from the root sets, creating a Struct in
        self.pure_nested_structs for each nested set, then marks the structs
        as used in requests / replies and detects recursive nesting.

        Raises an Exception if a set is used both as a root set and as a nest.
        """
        # Breadth-first walk over the sets, starting from the root sets
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                # NOTE: this set object is handed to Struct() while still
                # empty, filled in below, and then passed to set_inherited(),
                # so the order of the statements matters
                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                if 'type-value' in spec:
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'indexed-array':
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        # Nests referenced directly from a root set take the direction(s)
        # in which the referencing attribute is used
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                    struct.child_nests.add(child_name)
                    child = self.pure_nested_structs.get(child_name)
                    if child:
                        if not child.recursive:
                            struct.child_nests.update(child.child_nests)
                        child.request |= struct.request
                        child.reply |= struct.reply
                # A struct which (transitively) nests itself is recursive
                if attr_set in struct.child_nests:
                    struct.recursive = True

        # Sort again, now that the recursive flags are known
        self._sort_pure_types()
1115
1116    def _load_attr_use(self):
1117        for _, struct in self.pure_nested_structs.items():
1118            if struct.request:
1119                for _, arg in struct.member_list():
1120                    arg.request = True
1121            if struct.reply:
1122                for _, arg in struct.member_list():
1123                    arg.reply = True
1124
1125        for root_set, rs_members in self.root_sets.items():
1126            for attr, spec in self.attr_sets[root_set].items():
1127                if attr in rs_members['request']:
1128                    spec.request = True
1129                if attr in rs_members['reply']:
1130                    spec.reply = True
1131
1132    def _load_global_policy(self):
1133        global_set = set()
1134        attr_set_name = None
1135        for op_name, op in self.ops.items():
1136            if not op:
1137                continue
1138            if 'attribute-set' not in op:
1139                continue
1140
1141            if attr_set_name is None:
1142                attr_set_name = op['attribute-set']
1143            if attr_set_name != op['attribute-set']:
1144                raise Exception('For a global policy all ops must use the same set')
1145
1146            for op_mode in ['do', 'dump']:
1147                if op_mode in op:
1148                    req = op[op_mode].get('request')
1149                    if req:
1150                        global_set.update(req.get('attributes', []))
1151
1152        self.global_policy = []
1153        self.global_policy_set = attr_set_name
1154        for attr in self.attr_sets[attr_set_name]:
1155            if attr in global_set:
1156                self.global_policy.append(attr)
1157
1158    def _load_hooks(self):
1159        for op in self.ops.values():
1160            for op_mode in ['do', 'dump']:
1161                if op_mode not in op:
1162                    continue
1163                for when in ['pre', 'post']:
1164                    if when not in op[op_mode]:
1165                        continue
1166                    name = op[op_mode][when]
1167                    if name in self.hooks[when][op_mode]['set']:
1168                        continue
1169                    self.hooks[when][op_mode]['set'].add(name)
1170                    self.hooks[when][op_mode]['list'].append(name)
1171
1172
class RenderInfo:
    """
    Bundle of the state needed to render code for one operation in one mode,
    or for one attribute set when no operation is given.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        if op and op.fixed_header:
            self.fixed_hdr = 'struct ' + c_lower(op.fixed_header)
        else:
            self.fixed_hdr = None

        # 'do' and 'dump' response parsing is identical
        self.type_consistent = True
        if op_mode != 'do' and 'dump' in op:
            if 'do' not in op:
                self.type_consistent = False
            else:
                do_replies = 'reply' in op['do']
                dump_replies = 'reply' in op["dump"]
                if do_replies != dump_replies:
                    self.type_consistent = False
                elif do_replies and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False

        # Fall back to the attribute set of the operation
        self.attr_set = attr_set if attr_set else op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            self.type_name_conflict = attr_set in family.consts

        self.cw = cw

        self.struct = dict()
        # Notifications are rendered based on the 'do' section
        if op_mode == 'notify':
            op_mode = 'do'
        if op:
            mode_spec = op[op_mode]
            for op_dir in ('request', 'reply'):
                if op_dir in mode_spec:
                    type_list = mode_spec[op_dir]['attributes']
                else:
                    type_list = []
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1221
1222
1223class CodeWriter:
1224    def __init__(self, nlib, out_file=None, overwrite=True):
1225        self.nlib = nlib
1226        self._overwrite = overwrite
1227
1228        self._nl = False
1229        self._block_end = False
1230        self._silent_block = False
1231        self._ind = 0
1232        self._ifdef_block = None
1233        if out_file is None:
1234            self._out = os.sys.stdout
1235        else:
1236            self._out = tempfile.NamedTemporaryFile('w+')
1237            self._out_file = out_file
1238
    def __del__(self):
        """Finalizer: hand the collected output over via close_out_file()."""
        self.close_out_file()
1241
1242    def close_out_file(self):
1243        if self._out == os.sys.stdout:
1244            return
1245        # Avoid modifying the file if contents didn't change
1246        self._out.flush()
1247        if not self._overwrite and os.path.isfile(self._out_file):
1248            if filecmp.cmp(self._out.name, self._out_file, shallow=False):
1249                return
1250        with open(self._out_file, 'w+') as out_file:
1251            self._out.seek(0)
1252            shutil.copyfileobj(self._out, out_file)
1253            self._out.close()
1254        self._out = os.sys.stdout
1255
1256    @classmethod
1257    def _is_cond(cls, line):
1258        return line.startswith('if') or line.startswith('while') or line.startswith('for')
1259
1260    def p(self, line, add_ind=0):
1261        if self._block_end:
1262            self._block_end = False
1263            if line.startswith('else'):
1264                line = '} ' + line
1265            else:
1266                self._out.write('\t' * self._ind + '}\n')
1267
1268        if self._nl:
1269            self._out.write('\n')
1270            self._nl = False
1271
1272        ind = self._ind
1273        if line[-1] == ':':
1274            ind -= 1
1275        if self._silent_block:
1276            ind += 1
1277        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
1278        if line[0] == '#':
1279            ind = 0
1280        if add_ind:
1281            ind += add_ind
1282        self._out.write('\t' * ind + line + '\n')
1283
    def nl(self):
        """Request a blank line; p() writes it before the next printed line."""
        self._nl = True
1286
1287    def block_start(self, line=''):
1288        if line:
1289            line = line + ' '
1290        self.p(line + '{')
1291        self._ind += 1
1292
1293    def block_end(self, line=''):
1294        if line and line[0] not in {';', ','}:
1295            line = ' ' + line
1296        self._ind -= 1
1297        self._nl = False
1298        if not line:
1299            # Delay printing closing bracket in case "else" comes next
1300            if self._block_end:
1301                self._out.write('\t' * (self._ind + 1) + '}\n')
1302            self._block_end = True
1303        else:
1304            self.p('}' + line)
1305
1306    def write_doc_line(self, doc, indent=True):
1307        words = doc.split()
1308        line = ' *'
1309        for word in words:
1310            if len(line) + len(word) >= 79:
1311                self.p(line)
1312                line = ' *'
1313                if indent:
1314                    line += '  '
1315            line += ' ' + word
1316        self.p(line)
1317
1318    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
1319        if not args:
1320            args = ['void']
1321
1322        if doc:
1323            self.p('/*')
1324            self.p(' * ' + doc)
1325            self.p(' */')
1326
1327        oneline = qual_ret
1328        if qual_ret[-1] != '*':
1329            oneline += ' '
1330        oneline += f"{name}({', '.join(args)}){suffix}"
1331
1332        if len(oneline) < 80:
1333            self.p(oneline)
1334            return
1335
1336        v = qual_ret
1337        if len(v) > 3:
1338            self.p(v)
1339            v = ''
1340        elif qual_ret[-1] != '*':
1341            v += ' '
1342        v += name + '('
1343        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
1344        delta_ind = len(v) - len(ind)
1345        v += args[0]
1346        i = 1
1347        while i < len(args):
1348            next_len = len(v) + len(args[i])
1349            if v[0] == '\t':
1350                next_len += delta_ind
1351            if next_len > 76:
1352                self.p(v + ',')
1353                v = ind
1354            else:
1355                v += ', '
1356            v += args[i]
1357            i += 1
1358        self.p(v + ')' + suffix)
1359
1360    def write_func_lvar(self, local_vars):
1361        if not local_vars:
1362            return
1363
1364        if type(local_vars) is str:
1365            local_vars = [local_vars]
1366
1367        local_vars.sort(key=len, reverse=True)
1368        for var in local_vars:
1369            self.p(var)
1370        self.nl()
1371
1372    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
1373        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
1374        self.write_func_lvar(local_vars=local_vars)
1375
1376        self.block_start()
1377        for line in body:
1378            self.p(line)
1379        self.block_end()
1380
1381    def writes_defines(self, defines):
1382        longest = 0
1383        for define in defines:
1384            if len(define[0]) > longest:
1385                longest = len(define[0])
1386        longest = ((longest + 8) // 8) * 8
1387        for define in defines:
1388            line = '#define ' + define[0]
1389            line += '\t' * ((longest - len(define[0]) + 7) // 8)
1390            if type(define[1]) is int:
1391                line += str(define[1])
1392            elif type(define[1]) is str:
1393                line += '"' + define[1] + '"'
1394            self.p(line)
1395
1396    def write_struct_init(self, members):
1397        longest = max([len(x[0]) for x in members])
1398        longest += 1  # because we prepend a .
1399        longest = ((longest + 8) // 8) * 8
1400        for one in members:
1401            line = '.' + one[0]
1402            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
1403            line += '= ' + str(one[1]) + ','
1404            self.p(line)
1405
1406    def ifdef_block(self, config):
1407        config_option = None
1408        if config:
1409            config_option = 'CONFIG_' + c_upper(config)
1410        if self._ifdef_block == config_option:
1411            return
1412
1413        if self._ifdef_block:
1414            self.p('#endif /* ' + self._ifdef_block + ' */')
1415        if config_option:
1416            self.p('#ifdef ' + config_option)
1417        self._ifdef_block = config_option
1418
1419
# Attribute types treated as scalars: AttrSet.new_attr() maps them to
# TypeScalar and _multi_parse() reads them with ynl_attr_get_<type>()
scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64', 'uint', 'sint'}

# Suffix of the C names for each message direction,
# used by op_prefix() and print_prototype()
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': ''
}

# Suffix of the reply wrapper type for each op mode, used by op_prefix()
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# C keywords; AttrSet.resolve() appends '_' to names which collide with these
_C_KW = {
    'auto',
    'bool',
    'break',
    'case',
    'char',
    'const',
    'continue',
    'default',
    'do',
    'double',
    'else',
    'enum',
    'extern',
    'float',
    'for',
    'goto',
    'if',
    'inline',
    'int',
    'long',
    'register',
    'return',
    'short',
    'signed',
    'sizeof',
    'static',
    'struct',
    'switch',
    'typedef',
    'union',
    'unsigned',
    'void',
    'volatile',
    'while'
}
1471
1472
def rdir(direction):
    """
    Return the opposite of direction: 'request' for 'reply' and vice versa.

    Any other value is returned unchanged.
    """
    if direction == 'request':
        return 'reply'
    return 'request' if direction == 'reply' else direction
1479
1480
def op_prefix(ri, direction, deref=False):
    """
    Build the base C name for the type of ri in the given direction.

    The name is <family>_<type name> plus a suffix which depends on the op
    mode: 'do' (or no mode) uses the direction suffix, requests of other
    modes use '_req_dump'. Replies of other modes use the direction suffix
    (deref) or the mode's wrapper suffix when the reply type is consistent
    with 'do', and '_rsp_dump' (deref) or '_rsp_list' otherwise.
    """
    if not ri.op_mode or ri.op_mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[ri.op_mode]

    return f"{ri.family.c_name}_{ri.type_name}{tail}"
1500
1501
def type_name(ri, direction, deref=False):
    """Return the C struct type ('struct <name>') named by op_prefix()."""
    return 'struct ' + op_prefix(ri, direction, deref=deref)
1504
1505
def print_prototype(ri, direction, terminate=True, doc=None):
    """
    Print the prototype of the user facing function of the operation.

    The function takes the socket plus, if the op mode has a request,
    a pointer to the request struct. It returns a pointer to the reply type
    if the op mode has a reply and int otherwise. terminate selects whether
    the prototype ends with ';', doc is printed as a comment above it.
    """
    mode_spec = ri.op[ri.op_mode]

    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        # The argument is named after the direction suffix, less the '_'
        args.append(f"{type_name(ri, direction)} *{direction_to_suffix[direction][1:]}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1522
1523
def print_req_prototype(ri):
    """Print the request function prototype, documented with the op's 'doc'."""
    print_prototype(ri, "request", doc=ri.op['doc'])
1526
1527
def print_dump_prototype(ri):
    """Print the request function prototype without a doc comment."""
    print_prototype(ri, "request")
1530
1531
def put_typol_fwd(cw, struct):
    """Print the extern declaration of the policy nest of struct."""
    nest_name = struct.render_name + '_nest'
    cw.p('extern const struct ynl_policy_nest ' + nest_name + ';')
1534
1535
def put_typol(cw, struct):
    """
    Print the policy of struct: the table of per-attribute policies, filled
    in by attr_typol() of each member, and the nest object referring to it.
    """
    name = struct.render_name
    max_attr = struct.attr_set.max_name

    cw.block_start(line=f'const struct ynl_policy_attr {name}_policy[{max_attr} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {name}_nest =')
    for init in (f'.max_attr = {max_attr},', f'.table = {name}_policy,'):
        cw.p(init)
    cw.block_end(line=';')
    cw.nl()
1551
1552
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """
    Print the <render_name>_str() function which looks a value up in the
    string table map_name, returning NULL for values outside of the table.

    The argument is an int unless enum is given, in which case the user type
    of the enum is used. For enums of 'flags' type the value is converted to
    a bit number with ffs() before the lookup.
    """
    arg_type = enum.user_type if enum else 'int'
    cw.write_func_prot('const char *', render_name + '_str',
                       [arg_type + ' ' + arg_name])

    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    body = (
        f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))',
        'return NULL;',
        f'return {map_name}[{arg_name}];',
    )
    for line in body:
        cw.p(line)
    cw.block_end()
    cw.nl()
1566
1567
def put_op_name_fwd(family, cw):
    """Print the declaration of the family's <family>_op_str() function."""
    func_name = family.c_name + '_op_str'
    cw.write_func_prot('const char *', func_name, ['int op'], suffix=';')
1570
1571
def put_op_name(family, cw):
    """
    Print the table mapping message values to operation names and the
    <family>_op_str() function which performs the lookup.

    Only messages with a response value get an entry.
    """
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue

        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue

        # Use the enum name when request and response share the value
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
1591
1592
def put_enum_to_str_fwd(family, cw, enum):
    """Print the declaration of the <enum>_str() function of enum."""
    value_arg = f'{enum.user_type} value'
    cw.write_func_prot('const char *', enum.render_name + '_str', [value_arg],
                       suffix=';')
1596
1597
def put_enum_to_str(family, cw, enum):
    """
    Print the table mapping the values of enum to the entry names and the
    <enum>_str() function which performs the lookup.
    """
    map_name = enum.render_name + '_strmap'

    cw.block_start(line=f"static const char * const {map_name}[] =")
    rows = [f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values()]
    for row in rows:
        cw.p(row)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1607
1608
def put_req_nested_prototype(ri, struct, suffix=';'):
    """
    Print the prototype of the <struct>_put() function which takes the
    message, the attribute type and a pointer to the object.
    """
    obj_arg = struct.ptr_name + 'obj'
    func_args = ['struct nlmsghdr *nlh', 'unsigned int attr_type', obj_arg]
    func_name = struct.render_name + '_put'

    ri.cw.write_func_prot('int', func_name, func_args, suffix=suffix)
1616
1617
def put_req_nested(ri, struct):
    """
    Print the <struct>_put() function: it opens a nest of the given attribute
    type, lets attr_put() of each member emit its code, closes the nest and
    returns 0.
    """
    cw = ri.cw

    put_req_nested_prototype(ri, struct, suffix='')
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = ynl_attr_nest_start(nlh, attr_type);")

    for _, member in struct.member_list():
        member.attr_put(ri, "obj")

    cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1634
1635
def _multi_parse(ri, struct, init_lines, local_vars):
    """
    Print the body of a parsing function for struct (prototype excluded).

    The generated code iterates over the attributes of the nest (nested
    structs) or of the message, lets attr_get() of each member emit its
    handling, and afterwards allocates and fills the arrays of indexed-array
    and multi-attr members based on the element counts gathered by the loop.

    init_lines are printed right after the local variables. local_vars is
    extended in place with the variables the generated code needs.

    An attribute is treated as multi-attr only if its 'multi-attr' property
    is true, matching the condition AttrSet.new_attr() uses to wrap it in
    TypeMultiAttr; a plain attribute must not get array handling code.

    Raises an Exception for indexed-array sub-types other than 'nest' and
    for multi-attr types which are neither nests nor scalars.
    """
    if struct.nested:
        iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        if ri.fixed_hdr:
            local_vars += ['void *hdr;']
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"

    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] == 'nest':
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        # Same condition as new_attr(): an explicit false is not a multi-attr
        if 'multi-attr' in aspec and aspec['multi-attr']:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # Element counters, presumably incremented by the code attr_get() emits
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    # Inherited members arrive as function arguments of the same name
    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if ri.fixed_hdr:
        ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass over the elements of each indexed array
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
        ri.cw.p('return YNL_PARSE_CB_ERROR;')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass over the whole message / nest for each multi-attr
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1747
1748
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """
    Emit the prototype of the parser for the nested struct @struct.

    The parser takes the parse argument and the nest attribute, followed
    by one __u32 argument per entry in struct.inherited. @suffix is handed
    to the prototype writer (';' for a declaration).
    """
    inherited_args = ['__u32 ' + name for name in struct.inherited]
    params = ['struct ynl_parse_arg *yarg',
              'const struct nlattr *nested',
              *inherited_args]

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', params,
                          suffix=suffix)
1757
1758
def parse_rsp_nested(ri, struct):
    """
    Emit the definition of the parser for the nested struct @struct.

    Writes the function header (no terminating ';') and lets
    _multi_parse() generate the body, with dst pointing at yarg->data.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    dst_decl = f'{struct.ptr_name}dst = yarg->data;'
    # dst is initialized in its declaration, so no separate init lines
    _multi_parse(ri, struct, [], ['const struct nlattr *attr;', dst_decl])
1767
1768
def parse_rsp_msg(ri, deref=False):
    """
    Emit the parser for the reply (or event) message of the op in @ri.

    Nothing is generated when the current op mode has no 'reply' and is
    not 'event'. A reply without members gets a parser which only returns
    YNL_PARSE_CB_OK.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    dst_decl = f'{type_name(ri, "reply", deref=deref)} *dst;'
    parser_name = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    ri.cw.write_func_prot('int', parser_name,
                          ['const struct nlmsghdr *nlh',
                           'struct ynl_parse_arg *yarg'])

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply, ['dst = yarg->data;'],
                 [dst_decl, 'const struct nlattr *attr;'])
1790
1791
def print_req(ri):
    """
    Emit the C function which executes a request for the operation in @ri.

    The generated function starts the message, puts the request attributes
    and calls ynl_exec(). Without a 'reply' in the current op mode it
    returns 0 / -1; with a reply it allocates the response, returns it on
    success and frees it / returns NULL on error.
    """
    ret_ok = '0'
    ret_err = '-1'
    direction = "request"
    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    # Ops with a reply hand back the response object instead of a status
    if 'reply' in ri.op[ri.op_mode]:
        ret_ok = 'rsp'
        ret_err = 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;']

    if ri.fixed_hdr:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()

    # Copy the fixed header from the request struct into the message
    if ri.fixed_hdr:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    # Set up the response object and the parse callback for the reply
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        if ri.op.value is not None:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
        else:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
        ri.cw.nl()
    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
    ri.cw.p('if (err < 0)')
    # With a reply the error path has to release rsp, hence the goto
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('goto err_free;')
    else:
        ri.cw.p('return -1;')
    ri.cw.nl()

    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()

    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('err_free:')
        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        ri.cw.p(f"return {ret_err};")

    ri.cw.block_end()
1856
1857
def print_dump(ri):
    """
    Emit the C function which executes a dump for the operation in @ri.

    The generated function fills in a struct ynl_dump_state, starts the
    dump message, puts the request attributes (if the op mode has a
    'request') and calls ynl_exec_dump(). It returns yds.first on success,
    on error it frees the list and returns NULL.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if ri.fixed_hdr:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    ri.cw.write_func_lvar(local_vars)

    # Parse state: policy, element size and callback for each reply
    ri.cw.p('yds.yarg.ys = ys;')
    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.p("yds.yarg.data = NULL;")
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    if ri.op.value is not None:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
    else:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
    ri.cw.nl()
    ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    # Copy the fixed header from the request struct into the message
    if ri.fixed_hdr:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    # Error path: release whatever was collected before the failure
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
1908
1909
def call_free(ri, direction, var):
    """
    Return the C statement calling the free helper of the type selected
    by @direction on the variable @var.
    """
    prefix = op_prefix(ri, direction)
    return f"{prefix}_free({var});"
1912
1913
def free_arg_name(direction):
    """
    Return the name of the argument of a free helper.

    With a direction it is the matching entry of direction_to_suffix
    minus its first character, otherwise plain 'obj'.
    """
    return direction_to_suffix[direction][1:] if direction else 'obj'
1918
1919
def print_alloc_wrapper(ri, direction):
    """
    Emit a static inline <name>_alloc() helper which returns a zeroed
    (calloc'ed) struct for the type selected by @direction.
    """
    struct_name = op_prefix(ri, direction)
    ret_type = f'static inline struct {struct_name} *'

    ri.cw.write_func_prot(ret_type, f"{struct_name}_alloc", ["void"])
    ri.cw.block_start()
    ri.cw.p(f'return calloc(1, sizeof(struct {struct_name}));')
    ri.cw.block_end()
1926
1927
def print_free_prototype(ri, direction, suffix=';'):
    """
    Emit the prototype of the free helper for the type selected by
    @direction.

    The struct name gets a trailing underscore when ri.type_name_conflict
    is set. @suffix is handed to the prototype writer.
    """
    func_base = op_prefix(ri, direction)
    struct_name = func_base + '_' if ri.type_name_conflict else func_base
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{func_base}_free", [param], suffix=suffix)
1935
1936
def _print_type(ri, direction, struct):
    """
    Emit the C struct definition for @struct.

    The struct holds, in order: the fixed header (if any), an anonymous
    '_present' struct with the presence members reported by the
    attributes, the inherited __u32 values and the attribute members.
    """
    cw = ri.cw

    name_sfx = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        name_sfx += '_'
    if ri.op_mode == 'dump':
        name_sfx += '_dump'

    cw.block_start(line=f"struct {ri.family.c_name}{name_sfx}")

    if ri.fixed_hdr:
        cw.p(ri.fixed_hdr + ' _hdr;')
        cw.nl()

    # Lazily walk all (attribute, filter) pairs, the '_present' struct is
    # opened only once the first attribute reports a presence member
    presence_lines = (attr.presence_member(ri.ku_space, kind)
                      for _, attr in struct.member_list()
                      for kind in ['len', 'bit'])
    in_present = False
    for line in presence_lines:
        if not line:
            continue
        if not in_present:
            cw.block_start(line="struct")
            in_present = True
        cw.p(line)
    if in_present:
        cw.block_end(line='_present;')
        cw.nl()

    for arg in struct.inherited:
        cw.p(f"__u32 {arg};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
1972
1973
def print_type(ri, direction):
    """
    Emit the C struct for the @direction struct of the operation in @ri.
    """
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
1976
1977
def print_type_full(ri, struct):
    """
    Emit the C struct for @struct without any direction (empty direction
    string).
    """
    no_direction = ""
    _print_type(ri, no_direction, struct)
1980
1981
def print_type_helpers(ri, direction, deref=False):
    """
    Emit the helpers of a type: the free prototype and, for user space
    request types only, the setters of all members.
    """
    print_free_prototype(ri, direction)
    ri.cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    members = ri.struct[direction].member_list() if wants_setters else []
    for _, attr in members:
        attr.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1990
1991
def print_req_type_helpers(ri):
    """
    Emit the alloc wrapper and type helpers of the request type, unless
    the request has no attributes.
    """
    if len(ri.struct["request"].attr_list) > 0:
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
1997
1998
def print_rsp_type_helpers(ri):
    """
    Emit the type helpers of the reply type, if the current op mode has
    a 'reply' at all.
    """
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2003
2004
def print_parse_prototype(ri, direction, terminate=True):
    """
    Emit the prototype of the void <op>_rsp_parse() / <op>_req_parse()
    function, taking the attribute table and the struct to fill.

    The '_rsp' flavor is used for direction 'reply', '_req' for anything
    else. @terminate selects whether the prototype ends with ';'.
    """
    if direction == "reply":
        dir_sfx = "_rsp"
    else:
        dir_sfx = "_req"
    type_base = f"{ri.op.render_name}{dir_sfx}"

    params = ['const struct nlattr **tb',
              f"struct {type_base} *req"]
    ri.cw.write_func_prot('void', f"{type_base}_parse", params,
                          suffix=';' if terminate else '')
2013
2014
def print_req_type(ri):
    """
    Emit the request struct, unless the request has no attributes.
    """
    if len(ri.struct["request"].attr_list) > 0:
        print_type(ri, "request")
2019
2020
def print_req_free(ri):
    """
    Emit the free function of the request type, if the current op mode
    has a 'request'.
    """
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2025
2026
def print_rsp_type(ri):
    """
    Emit the reply struct.

    A reply type exists for 'do' and 'dump' modes which define a 'reply'
    and for 'event' mode; in all other cases nothing is generated.
    """
    mode = ri.op_mode
    if mode in ('do', 'dump'):
        has_type = 'reply' in ri.op[mode]
    else:
        has_type = mode == 'event'

    if not has_type:
        return
    print_type(ri, 'reply')
2035
2036
def print_wrapped_type(ri):
    """
    Emit the struct which wraps the reply object, plus the prototype of
    its free function.

    In 'dump' mode the wrapper carries a 'next' pointer of its own type;
    in 'notify' and 'event' modes it carries family, cmd, a next pointer
    and a free callback. The reply itself is the 8 byte aligned 'obj'
    member.
    """
    cw = ri.cw
    mode = ri.op_mode

    cw.block_start(line=f"{type_name(ri, 'reply')}")
    if mode == 'dump':
        cw.p(f"{type_name(ri, 'reply')} *next;")
    elif mode in ('notify', 'event'):
        for member in ('__u16 family;',
                       '__u8 cmd;',
                       'struct ynl_ntf_base_type *next;'):
            cw.p(member)
        cw.p(f"void (*free)({type_name(ri, 'reply')} *ntf);")
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()

    print_free_prototype(ri, 'reply')
    cw.nl()
2051
2052
def _free_type_members_iter(ri, struct):
    """
    Emit the declaration of the iterator variable 'i', if at least one
    member of @struct reports that its free code needs one.
    """
    if any(attr.free_needs_iter() for _, attr in struct.member_list()):
        ri.cw.p('unsigned int i;')
        ri.cw.nl()
2059
2060
def _free_type_members(ri, var, struct, ref=''):
    """
    Have every member of @struct emit its free code for the variable
    @var; @ref is passed through to each member's free().
    """
    members = (member for _, member in struct.member_list())
    for member in members:
        member.free(ri, var, ref)
2064
2065
def _free_type(ri, direction, struct):
    """
    Emit the definition of the free function for @struct.

    All members are released; the object itself is free()d only when
    @direction is non-empty.
    """
    obj_name = free_arg_name(direction)
    cw = ri.cw

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, obj_name, struct)
    if direction:
        cw.p(f'free({obj_name});')
    cw.block_end()
    cw.nl()
2077
2078
def free_rsp_nested_prototype(ri):
    """
    Emit the prototype of the free function of a nested type (no
    direction).
    """
    no_direction = ""
    print_free_prototype(ri, no_direction)
2081
2082
def free_rsp_nested(ri, struct):
    """
    Emit the definition of the free function of the nested type @struct
    (no direction).
    """
    no_direction = ""
    _free_type(ri, no_direction, struct)
2085
2086
def print_rsp_free(ri):
    """
    Emit the free function of the reply type, if the current op mode has
    a 'reply'.
    """
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2091
2092
def print_dump_type_free(ri):
    """
    Emit the free function for the list of objects returned by a dump.

    The generated code walks the list until YNL_LIST_END, releasing the
    members of each entry's 'obj' and then the entry itself.
    """
    cw = ri.cw
    list_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{list_type} *next = rsp;")
    cw.nl()

    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, ri.struct['reply'])
    # Advance first, the current entry is gone by the end of the loop body
    cw.p('rsp = next;')
    cw.p('next = rsp->next;')
    cw.nl()

    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()

    cw.block_end()
    cw.nl()
2111
2112
def print_ntf_type_free(ri):
    """
    Emit the free function of a notification: release the members of
    the wrapped 'obj', then the notification itself.
    """
    cw = ri.cw
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
2121
2122
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """
    Emit the first line of a netlink policy array.

    With @terminate set this is an 'extern' declaration, which is skipped
    entirely for per-op policies (@ri given) that should be static.
    Without @terminate it is the opening of the definition ('= {'),
    marked 'static' under the same condition.

    The array is named after the operation when @ri is given (with the
    op mode appended for dual policies), after @struct otherwise.
    """
    # Evaluated at most once, and only for per-op policies
    make_static = bool(ri) and policy_should_be_static(struct.family)

    if terminate:
        if make_static:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if make_static else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2145
2146
def print_req_policy(cw, struct, ri=None):
    """
    Emit the definition of a netlink policy array for @struct.

    When @ri carries an operation, the definition is wrapped in the
    ifdef block of the operation's 'config-cond' (if any).
    """
    op = ri.op if ri else None
    if op:
        cw.ifdef_block(op.get('config-cond', None))

    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")

    # Close the conditional block, if one was opened
    cw.ifdef_block(None)
    cw.nl()
2156
2157
def kernel_can_gen_family_struct(family):
    """
    Return True if the family struct can be generated for @family, which
    is the case only for the 'genetlink' protocol.
    """
    protocol = family.proto
    return protocol == 'genetlink'
2160
2161
def policy_should_be_static(family):
    """
    Return whether policies of @family should be static: true for the
    'split' kernel policy or when the family struct can be generated.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2164
2165
def print_kernel_policy_ranges(family, cw):
    """
    Emit the netlink_range_validation structs used by kernel policies.

    One struct is written for each request attribute with a 'full-range'
    check, skipping attribute sets which are subsets of another set.
    A heading comment precedes the first struct.
    """
    # Lazy, so attributes are inspected in step with the output
    ranged_attrs = (attr
                    for _, attr_set in family.attr_sets.items()
                    if not attr_set.subset_of
                    for _, attr in attr_set.items()
                    if attr.request and 'full-range' in attr.checks)

    for idx, attr in enumerate(ranged_attrs):
        if idx == 0:
            cw.p('/* Integer value ranges */')

        unsigned = attr.type[0] == 'u'
        sign = '' if unsigned else '_signed'
        suffix = 'ULL' if unsigned else 'LL'
        cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
        members = [(limit, attr.get_limit_str(limit, suffix=suffix))
                   for limit in ('min', 'max') if limit in attr.checks]
        cw.write_struct_init(members)
        cw.block_end(line=';')
        cw.nl()
2193
2194
def print_kernel_op_table_fwd(family, cw, terminate):
    """
    Emit the declaration or the opening of the kernel ops table.

    With @terminate set (declaration mode) an 'extern' declaration of the
    table is written if the table is exported, followed by prototypes of
    the pre/post hooks and of the doit/dumpit callbacks. Without
    @terminate only the start of the table definition is written.
    """
    # The table is exported only when the family struct is not generated
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.ident_name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        # Exported tables carry an explicit element count, non-exported
        # ones leave the array size empty
        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            # Split ops have a separate entry for do and for dump
            cnt = 0
            for op in family.ops.values():
                if 'do' in op:
                    cnt += 1
                if 'dump' in op:
                    cnt += 1
        else:
            cnt = len(family.ops)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    # The prototypes below are only part of the declaration output
    if not terminate:
        return

    cw.nl()
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    # Prototypes of the doit / dumpit handlers, async ops have none
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.ident_name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.ident_name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
2260
2261
def print_kernel_op_table_hdr(family, cw):
    """
    Emit the declaration flavor (terminate=True) of the kernel ops table
    output for @family.
    """
    declaration_only = True
    print_kernel_op_table_fwd(family, cw, terminate=declaration_only)
2264
2265
def print_kernel_op_table(family, cw):
    """
    Emit the definition of the kernel ops table for @family.

    For the 'global' and 'per-op' kernel policies there is one entry per
    (non-async) operation; for 'split' there is one entry per operation
    and mode ('do' / 'dump'). Entries are wrapped in the ifdef block of
    the operation's 'config-cond'.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): reads op['do'] unconditionally, looks like
                # this assumes every per-op operation has a 'do' with
                # a request - confirm for dump-only operations
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Names of the struct members holding the pre / post hooks
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the flags which apply to this mode
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual policy ops have a separate policy for each mode
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split entry is flagged with the mode it handles
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2344
2345
def print_kernel_mcgrp_hdr(family, cw):
    """
    Emit the anonymous enum with one <FAMILY>_NLGRP_<NAME> entry per
    multicast group. Nothing is written for families without groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        entry = c_upper(f"{family.ident_name}-nlgrp-{grp['name']},")
        cw.p(entry)
    cw.block_end(';')
    cw.nl()
2356
2357
def print_kernel_mcgrp_src(family, cw):
    """
    Emit the genl_multicast_group array of the family, indexed by the
    <FAMILY>_NLGRP_<NAME> identifiers. Nothing is written for families
    without multicast groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    table_decl = 'static const struct genl_multicast_group ' + family.c_name + '_nl_mcgrps[] ='
    cw.block_start(table_decl)
    for grp in groups:
        grp_name = grp['name']
        index = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        entry = '[' + index + '] = { "' + grp_name + '", },'
        cw.p(entry)
    cw.block_end(';')
    cw.nl()
2369
2370
def print_kernel_family_struct_hdr(family, cw):
    """
    Emit the extern declaration of the genl_family struct, plus the
    prototypes of the sock_priv init / destroy helpers when the family
    defines 'sock-priv'. Nothing is written if the family struct can't
    be generated.
    """
    if not kernel_can_gen_family_struct(family):
        return

    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
    cw.nl()

    if 'sock-priv' not in family.kernel_family:
        return
    for hook in ('init', 'destroy'):
        cw.p(f'void {family.c_name}_nl_sock_priv_{hook}({family.kernel_family["sock-priv"]} *priv);')
    cw.nl()
2381
2382
def print_kernel_family_struct_src(family, cw):
    """
    Emit the definition of the genl_family struct for @family.

    The struct points at the ops table (for the 'per-op' and 'split'
    kernel policies), at the multicast groups (if any) and carries the
    sock_priv size and helpers when the family defines 'sock-priv'.
    Nothing is written if the family struct can't be generated.
    """
    if not kernel_can_gen_family_struct(family):
        return

    # Name the symbol with c_name, exactly like the extern declaration and
    # the ops / mcgrps identifiers below, so that the declaration and the
    # definition can never get out of sync.
    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    if 'sock-priv' in family.kernel_family:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        # Force cast here, actual helpers take pointer to the real type.
        cw.p(f'.sock_priv_init\t= (void *){family.c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = (void *){family.c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
2408
2409
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """
    Open the block of a uAPI enum, named based on @obj.

    If the @enum_name key is present in @obj it decides the name:
    a non-empty value is used as the enum name, an empty one gives an
    anonymous enum. Otherwise, if @ckey is given and present in @obj,
    the name is the family C name joined with the value under @ckey.
    """
    tag = ''
    if enum_name in obj:
        explicit = obj[enum_name]
        if explicit:
            tag = c_lower(explicit)
    elif ckey and ckey in obj:
        tag = family.c_name + '_' + c_lower(obj[ckey])

    cw.block_start(line=('enum ' + tag) if tag else 'enum')
2418
2419
def render_uapi(family, cw):
    """
    Emit the uAPI header for @family.

    The header consists of, in order: the include guard, the family name
    and version defines, the constants / enums / flags from the family's
    'definitions', one enum per attribute set, the command enum(s) and
    the multicast group name defines.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive 'const' definitions are collected and written out as
    # one group of defines once a definition of another type shows up
    defines = []
    for const in family['definitions']:
        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                # Open a kdoc comment if entries are documented, otherwise
                # a plain comment holding just the enum's own doc
                if enum.has_entry_doc():
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            # Optional trailing entries: the mask for flags, the
            # __MAX / MAX pair for enums
            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    max_name = c_upper(name_pfx + 'max')
                    cw.p('__' + max_name + ',')
                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            defines.append([c_upper(family.get('c-define-name',
                                               f"{family.ident_name}-{const['name']}")),
                            const['value']])

    # Flush the constants which trailed the last enum / flags definition
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    # Attribute enums, subsets share the enum of the set they derive from
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        # val tracks the value the next entry would get implicitly, an
        # explicit value is rendered only when the attribute's differs
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        # The max is either the last enum entry or a define after the enum
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    val = 0
    for op in family.msgs.values():
        # Notifications and events go to their own enum below
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2568
2569
def _render_user_ntf_entry(ri, op):
    """
    Emit a single designated-initializer entry of the notification info
    table, the entry is indexed by the enum name of @op
    """
    out = ri.cw

    out.block_start(line=f"[{op.enum_name}] = ")

    alloc_type = type_name(ri, 'event')
    out.p(f".alloc_sz\t= sizeof({alloc_type}),")

    parse_base = op_prefix(ri, 'reply', deref=True)
    out.p(f".cb\t\t= {parse_base}_parse,")

    policy_base = ri.struct['reply'].render_name
    out.p(f".policy\t\t= &{policy_base}_nest,")

    free_base = op_prefix(ri, 'notify')
    out.p(f".free\t\t= (void *){free_base}_free,")

    out.block_end(line=',')
2577
2578
def render_user_family(family, cw, prototype):
    """
    Render the user space 'struct ynl_family' object of the family

    If @prototype is set only the extern declaration is written out.
    Otherwise the notification info table (when the family has any
    notifications) and the definition of the family object are written.

    Raises Exception if a notification is neither a 'notify' nor an 'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    # The spec name of a family may contain dashes, C identifiers have to be
    # built from c_name, same as the family symbol above
    ntf_info = f"{family.c_name}_ntf_info"

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {ntf_info}[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                # Notifications reuse the types of the operation they refer to
                op = family.ops[ntf_op['notify']]
                ri = RenderInfo(cw, family, "user", op, "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        for _, op in family.ops.items():
            if 'event' not in op:
                continue
            ri = RenderInfo(cw, family, "user", op, "event")
            _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {ntf_info},")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({ntf_info}),")
    cw.block_end(line=';')
2614
2615
def family_contains_bitfield32(family):
    """
    Tell whether any attribute set of the family, other than those which
    are a subset of another set, has an attribute of type bitfield32
    """
    return any(
        attr.type == "bitfield32"
        for _, attr_set in family.attr_sets.items()
        if not attr_set.subset_of
        for _, attr in attr_set.items()
    )
2624
2625
def find_kernel_root(full_path):
    """
    Find the root of the kernel tree containing @full_path

    Walk up the parent directories of @full_path until one which holds
    a MAINTAINERS file is found. Returns a tuple of that directory and
    the path of @full_path relative to it.

    Raises FileNotFoundError if no parent directory holds a MAINTAINERS file.
    """
    spec_path = full_path
    sub_path = ''
    while True:
        parent = os.path.dirname(full_path)
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # Joining with the initial '' left a trailing separator, drop it
            return parent, sub_path[:-1]
        if parent == full_path:
            # dirname() stopped making progress ('/' or ''), without this
            # check the loop would spin forever outside of a kernel tree
            raise FileNotFoundError(
                f"No MAINTAINERS file found in any parent directory of {spec_path}")
        full_path = parent
2634
2635
def main():
    """
    Entry point of the generator

    Parse the command line, load the spec and render the requested flavor
    of output, selected with --mode (uapi, kernel or user) and --header
    or --source. Exits with status 1 if the spec can't be parsed, carries
    an unexpected license or uses a message enum-model not supported
    by the selected mode.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    # --header and --source share the destination, None means neither was given
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
    except yaml.YAMLError as exc:
        print(exc)
        sys.exit(1)

    if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
        print('Spec license:', parsed.license)
        print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
        sys.exit(1)

    supported_models = ['unified']
    if args.mode in ['user', 'kernel']:
        supported_models += ['directional']
    if parsed.msg_id_model not in supported_models:
        print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation')
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    # File header: license, origin of the file and the generator arguments
    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    # -o is optional, the name of the matching header can only be derived
    # (and the header included) when an output file name was given
    hdr_file = None
    if args.out_file:
        hdr_file = os.path.splitext(os.path.basename(args.out_file))[0] + ".h"

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if hdr_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            if hdr_file:
                cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = [parsed.uapi_header]
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            # Only print the section comment if at least one policy follows
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for op_name, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            print_kernel_policy_ranges(parsed, cw)

            # Only print the section comment if at least one policy follows
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            for op_name, op in parsed.ops.items():
                if parsed.kernel_policy in {'per-op', 'split'}:
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    if not ri.type_consistent:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            has_recursive_nests = False
            cw.p('/* Policies */')
            for struct in parsed.pure_nested_structs.values():
                if struct.recursive:
                    put_typol_fwd(cw, struct)
                    has_recursive_nests = True
            if has_recursive_nests:
                cw.nl()
            for name in parsed.pure_nested_structs:
                struct = Struct(parsed, name)
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            if has_recursive_nests:
                # Prototypes go first when any of the nests is recursive
                for attr_set, struct in parsed.pure_nested_structs.items():
                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                    free_rsp_nested_prototype(ri)
                    if struct.request:
                        put_req_nested_prototype(ri, struct)
                    if struct.reply:
                        parse_rsp_nested_prototype(ri, struct)
                cw.nl()
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent:
                        parse_rsp_msg(ri, deref=True)
                    print_req_free(ri)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
2942
2943
# Run the generator only when executed as a script, not when imported
if __name__ == "__main__":
    main()
2946