xref: /linux/tools/net/ynl/ynl-gen-c.py (revision eb4f99c56ad30cb0f8c8e93a78b1200f5987e41e)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import os
8import re
9import shutil
10import tempfile
11import yaml
12
13from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
14
15
def c_upper(name):
    """Return *name* as an upper-case C identifier ('foo-bar' -> 'FOO_BAR')."""
    shouted = name.upper()
    return shouted.replace('-', '_')
18
19
def c_lower(name):
    """Return *name* as a lower-case C identifier ('Foo-Bar' -> 'foo_bar')."""
    lowered = name.lower()
    return lowered.replace('-', '_')
22
23
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value

    The first character is the signedness ('u' or 's'), followed by the bit
    width and a 4-character '-min' / '-max' suffix.
    """
    signedness, rest = name[0], name[1:]
    wants_min = name.endswith('-min')

    # The minimum of any unsigned type is simply zero
    if signedness == 'u' and wants_min:
        return 0

    bits = int(rest[:-4])
    if signedness == 's':
        # One bit is taken by the sign
        bits -= 1
    top = (1 << bits) - 1
    return -top - 1 if (signedness == 's' and wants_min) else top
37
38
class BaseNlLib:
    """C expressions the generated code uses to reach YNL library state."""

    def get_family_id(self):
        """Return the C expression that holds the resolved family ID."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
42
43
class Type(SpecAttr):
    """Base C code generator for one attribute of an attribute set.

    Subclasses specialize it per spec type by overriding hooks such as
    presence_type(), arg_member(), _complex_member_type(), _attr_policy(),
    _attr_typol(), attr_put(), _attr_get() and _setter_lines(). Generated C
    is written through ri.cw (a code writer supplied by the caller).
    """
    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        # Spec 'checks' dict; subclasses (e.g. TypeScalar) add derived keys to it
        self.checks = attr.get('checks', {})

        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            # Nesting the attr set named after the family drops the set-name part
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # Trailing '_' when a family constant shares the attr set's name,
            # presumably to avoid a C type name clash with it -- confirm
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        self.c_name = c_lower(self.name)
        # Names listed in _C_KW (presumably the C keywords) get a '_' suffix
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve():
        # (declared here for readers, then deleted so that any access before
        # resolve() runs raises AttributeError)
        self.enum_name = None
        delattr(self, "enum_name")

    def get_limit(self, limit, default=None):
        """Return check *limit* as an integer.

        Falls back to *default* when the check is absent; None is returned
        unchanged. String limits such as 'u32-max' are converted with
        limit_to_number(); limits naming a family constant raise, as resolving
        them is not implemented.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            raise Exception("Resolving family constants not implemented, yet")
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """Return check *limit* rendered as a C expression string.

        Returns '' when the limit is absent and *default* is None. Integers
        are printed with *suffix* appended; a family constant becomes the
        upper-cased '<family name>-<constant>' identifier; any other string
        (e.g. 'u32-max') is upper-cased with '-' replaced by '_'.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute enum_name, the C enum entry name of this attribute.

        A per-attribute 'name-prefix' overrides the attr set's prefix.
        """
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

    def is_multi_val(self):
        """Return a truthy value for types stored as arrays; None (falsy) here."""
        return None

    def is_scalar(self):
        """Return True for the fixed-width integer spec types."""
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        """Return True when the type (transitively) nests itself; never here."""
        return False

    def is_recursive_for_op(self, ri):
        """Return True when recursive and *ri* carries no op (ri.op is falsy)."""
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Return how presence is tracked: 'bit' here; subclasses use 'len', 'count' or ''."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the `_present` struct member declaration for this attribute.

        Returns None unless presence_type() equals *type_filter*, and also for
        presence types other than 'bit' and 'len'. *space* == 'user' selects
        the '__u32' spelling instead of 'u32'.
        """
        if self.presence_type() != type_filter:
            return

        if self.presence_type() == 'bit':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() == 'len':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}_len;"

    def _complex_member_type(self, ri):
        """Return the C type used for the struct member, or None for simple types."""
        return None

    def free_needs_iter(self):
        """Return True when the emitted free code uses a loop index; base types don't."""
        return False

    def free(self, ri, var, ref):
        """Emit free() of the member for multi-value and length-tracked types.

        *var* is the C variable and *ref* a member-path prefix (e.g. 'a.').
        """
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the C parameter declaration(s) for this attribute's setter.

        Types with a complex member type take a pointer, plus an
        'unsigned int n_<name>' count for 'count' presence. Other types must
        override this method.
        """
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member(s) that hold this attribute's value."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            # Recursive nests outside an op are held by pointer
            if self.is_recursive_for_op(ri):
                ptr = '*'
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Return the kernel policy initializer for C policy type *policy*."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attribute's entry in the kernel netlink policy array."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the body of the user-space type policy entry; subclasses override."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attribute's entry in the user-space YNL type policy."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        """Emit *line*, guarded by the presence check for 'bit'/'len' presence."""
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a guarded ynl_attr_put_<put_type>() call for this attribute."""
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        """Emit code that serializes this attribute; subclasses override."""
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return (lines, init_lines, local_vars) for parsing; subclasses override.

        Each element may be a str, a list of str or None.
        """
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the parse branch for this attribute.

        *first* selects 'if' for the first branch, 'else if' for later ones.
        Always returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if type(lines) is str:
            lines = [lines]
        if type(init_lines) is str:
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        # Single-valued attributes are validated and marked present here;
        # multi-value types skip this
        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        """Return the type-specific body lines of the setter; subclasses override."""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a static inline setter for this attribute.

        *ref* is the list of enclosing nest member names. The setter marks
        each enclosing nest present (and this attribute too when it uses bit
        presence), then appends the lines from _setter_lines(). Setters that
        free() without allocating get a '__' prefix on their name.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        for i in range(0, len(ref)):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{ref[i]}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == len(ref) - 1 and self.presence_type() != 'bit':
                continue
            code.append(presence + ' = 1;')
        # `presence` now names the last layer's _present member; 'len' types
        # append '_len' to it in their _setter_lines()
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = bool([x for x in code if 'free(' in x])
        alloc = bool([x for x in code if 'alloc(' in x])
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
255
256
class TypeUnused(Type):
    """Placeholder for an 'unused' attribute slot.

    The slot is marked YNL_PT_REJECT in the user-space type policy; no kernel
    policy entry, put code, parse branch or setter is emitted for it.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def attr_policy(self, cw):
        """No kernel policy entry is emitted."""

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        error_line = 'return YNL_PARSE_CB_ERROR;'
        return [error_line], None, None

    def attr_put(self, ri, var):
        """Nothing is serialized."""

    def attr_get(self, ri, var, first):
        """No parse branch is emitted."""

    def setter(self, ri, space, direction, deref=False, ref=None):
        """No setter is emitted."""
281
282
class TypePad(Type):
    """Padding attribute.

    It is marked YNL_PT_IGNORE in the user-space type policy; no kernel
    policy entry, put code, parse branch or setter is emitted for it.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def attr_policy(self, cw):
        """No kernel policy entry is emitted."""

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def attr_put(self, ri, var):
        """Padding is never serialized."""

    def attr_get(self, ri, var, first):
        """No parse branch is emitted."""

    def setter(self, ri, space, direction, deref=False, ref=None):
        """No setter is emitted."""
304
305
class TypeScalar(Type):
    """Integer attribute (u8/u16/u32/u64/s32/s64 and the auto-sized scalars).

    Derives range checks from an associated enum when the spec does not give
    them, and renders checks as NLA_POLICY_* kernel policy entries.
    """
    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # An associated enum implies min/max checks unless the spec sets them
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if 'min' not in self.checks:
                # A minimum of 0 is left implicit for unsigned types
                if low != 0 or self.type[0] == 's':
                    self.checks['min'] = low
            if 'max' not in self.checks:
                self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits that don't fit in s16 are emitted via NLA_POLICY_FULL_RANGE
        # (see _attr_policy) instead of the inline min/max forms
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Set is_bitfield and type_name (the C type used for the value)."""
        self.resolve_up(super())

        if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']:
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            # Auto-sized scalars (is_auto_scalar comes from SpecAttr) use 64 bits
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _attr_policy(self, policy):
        """Return the kernel policy initializer, most specific check first.

        Flag masks take precedence, then full-range, range, min and max.
        """
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
            else:
                enum = self.family.consts[self.checks['flags-mask']]
            # OR together the bit of every entry. Deriving the mask from the
            # entry count alone is wrong when flag values don't start at
            # bit 0 or are not contiguous.
            mask = enum.get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
392
393
class TypeFlag(Type):
    """Zero-length 'flag' attribute: its presence alone carries the value."""

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        return []

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # Nothing to copy out; attr_get() already records presence
        no_lines = []
        return no_lines, None, None

    def _setter_lines(self, ri, member, presence):
        return []
409
410
class TypeString(Type):
    """String attribute held in a malloc'ed, NUL-terminated char buffer.

    Presence is tracked by length (`_present.<name>_len`).
    """

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def attr_policy(self, cw):
        # 'unterminated-ok' relaxes the kernel policy from NUL_STRING to STRING
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        parts = ['{ .type = ' + policy]
        if 'max-len' in self.checks:
            parts.append(', .len = ' + self.get_limit_str('max-len'))
        parts.append(', }')
        return ''.join(parts)

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        get_lines = [f"{len_mem} = len;",
                     f"{dst} = malloc(len + 1);",
                     f"memcpy({dst}, ynl_attr_get_str(attr), len);",
                     f"{dst}[len] = 0;"]
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"free({member});",
                f"{length} = strlen({self.c_name});",
                f"{member} = malloc({length} + 1);",
                f'memcpy({member}, {self.c_name}, {length});',
                f'{member}[{length}] = 0;']
461
462
class TypeBinary(Type):
    """Opaque binary attribute held in a malloc'ed buffer.

    Presence is tracked by length (`_present.<name>_len`).
    """

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        # At most one check is supported: either 'exact-len' or 'min-len'
        check_names = list(self.checks)
        if len(check_names) > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        if not check_names:
            return '{ .type = NLA_BINARY, }'

        (check_name,) = check_names
        if check_name == 'exact-len':
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        if check_name == 'min-len':
            return '{ .len = ' + self.get_limit_str('min-len') + ', }'
        raise Exception('Unsupported check for binary type: ' + check_name)

    def attr_put(self, ri, var):
        length = f"{var}->_present.{self.c_name}_len"
        line = f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, {length})"
        self._attr_put_line(ri, var, line)

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        get_lines = [f"{len_mem} = len;",
                     f"{dst} = malloc(len);",
                     f"memcpy({dst}, ynl_attr_data(attr), len);"]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"free({member});",
                f"{length} = len;",
                f"{member} = malloc({length});",
                f'memcpy({member}, {self.c_name}, {length});']
512
513
class TypeBitfield32(Type):
    """'bitfield32' attribute, stored as a struct nla_bitfield32.

    The kernel policy mask comes from the enum named by the spec's 'enum'.
    """

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        flags_enum = self.family.consts[self.attr['enum']]
        return f"NLA_POLICY_BITFIELD32({flags_enum.get_mask(as_flags=True)})"

    def attr_put(self, ri, var):
        size = 'sizeof(struct nla_bitfield32)'
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, {size})"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        size = 'sizeof(struct nla_bitfield32)'
        copy_line = f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), {size});"
        return copy_line, None, None

    def _setter_lines(self, ri, member, presence):
        size = 'sizeof(struct nla_bitfield32)'
        return [f"memcpy(&{member}, {self.c_name}, {size});"]
537
538
class TypeNest(Type):
    """Nested attribute set, rendered as an embedded struct.

    When the nest is recursive and there is no op, it is held by pointer.
    """

    def is_recursive(self):
        nested = self.family.pure_nested_structs[self.nested_attrs]
        return nested.recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def free(self, ri, var, ref):
        target = f"{var}->{ref}{self.c_name}"
        if self.is_recursive_for_op(ri):
            # Held by pointer: only free when set, and pass the pointer as is
            ri.cw.p(f'if ({target})')
            ri.cw.p(f'{self.nested_render_name}_free({target});')
        else:
            ri.cw.p(f'{self.nested_render_name}_free(&{target});')

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        amp = '' if self.is_recursive_for_op(ri) else '&'
        put_call = f"{self.nested_render_name}_put(nlh, {self.enum_name}, {amp}{var}->{self.c_name})"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        parse_lines = [f"if ({self.nested_render_name}_parse(&parg, attr))",
                       "return YNL_PARSE_CB_ERROR;"]
        setup_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                       f"parg.data = &{var}->{self.c_name};"]
        return parse_lines, setup_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        # Emit a setter for every non-recursive member, one nest level deeper
        path = (ref or []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        members = (member for _, member in nested.member_list()
                   if not member.is_recursive())
        for member in members:
            member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
578
579
class TypeMultiAttr(Type):
    """Attribute that may repeat ('multi-attr: true'), collected into an array.

    *base_type* is the Type built for a single occurrence; kernel and
    user-space policy output is delegated to it. The element count is kept
    in n_<c_name>.
    """
    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def _is_nest(self):
        """Return True when elements are nests (also the default without 'type')."""
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if self._is_nest():
            return self.nested_struct_type
        elif self.attr['type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        else:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        # Nested elements are freed one by one in a loop over `i`
        return self._is_nest()

    def free(self, ri, var, ref):
        # The nest check comes first, as in _complex_member_type(), so the
        # "no 'type'" default is reachable instead of hitting a KeyError
        if self._is_nest():
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        elif self.attr['type'] in scalars:
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        else:
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only counts occurrences here; element values are not copied by this line
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        if self._is_nest():
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        elif self.attr['type'] in scalars:
            put_type = self.type
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # For multi-attr we have a count, not presence, hack up the presence
        presence = presence[:-(len('_present.') + len(self.c_name))] + "n_" + self.c_name
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
641
642
class TypeArrayNest(Type):
    """'indexed-array' attribute: a nest whose members are array elements."""

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        # Without an explicit 'sub-type' the elements are nests
        sub_type = self.attr.get('sub-type', 'nest')
        if sub_type == 'nest':
            return self.nested_struct_type
        if sub_type in scalars:
            prefix = '__' if ri.ku_space == 'user' else ''
            return prefix + sub_type
        raise Exception(f"Sub-type {sub_type} not supported yet")

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remember the nest and count its members
        counter = f'{var}->n_{self.c_name}'
        get_lines = [f'attr_{self.c_name} = attr;',
                     'ynl_attr_for_each_nested(attr2, attr)',
                     f'\t{counter}++;']
        return get_lines, None, ['const struct nlattr *attr2;']
668
669
class TypeNestTypeValue(Type):
    """'nest-type-value' attribute: a nest whose attribute types carry values.

    For each entry in the spec's 'type-value' list, parsing descends one
    nesting level and records that level's attribute type. Those types are
    passed as extra arguments to <nested>_parse().
    """
    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """Return (get_lines, init_lines, local_vars) for parsing this nest."""
        prev = 'attr'
        tv_args = ''
        get_lines = []
        local_vars = []
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        if 'type-value' in self.attr:
            tv_names = [c_lower(x) for x in self.attr["type-value"]]
            # One attr pointer and one __u32 type variable per level
            local_vars += [f'const struct nlattr *attr_{", *attr_".join(tv_names)};']
            local_vars += [f'__u32 {", ".join(tv_names)};']
            # Step one level down per entry: the level's attr is the payload
            # of the previous one, and its type is stored in the level variable
            for level in self.attr["type-value"]:
                level = c_lower(level)
                get_lines += [f'attr_{level} = ynl_attr_data({prev});']
                get_lines += [f'{level} = ynl_attr_type(attr_{level});']
                prev = 'attr_' + level

            tv_args = f", {', '.join(tv_names)}"

        # Parse starting from the innermost level reached above
        get_lines += [f"{self.nested_render_name}_parse(&parg, {prev}{tv_args});"]
        return get_lines, init_lines, local_vars
698
699
class Struct:
    """C struct rendering of an attribute set, or of a subset of its attributes.

    With *type_list* None the struct is a nest covering every attribute of
    the set; otherwise it holds only the listed attribute names, in order.
    """
    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        # The attr set named after the family drops the set-name part
        if family.name == c_lower(space_name):
            base_name = family.ident_name
        else:
            base_name = family.ident_name + '-' + space_name
        self.render_name = c_lower(base_name)
        self.struct_name = 'struct ' + self.render_name
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in names]
        self.attrs = dict(self.attr_list)

        # Track the attribute with the highest value (last one wins on ties)
        self.attr_max_val = None
        highest = 0
        for _, attr in self.attr_list:
            if attr.value >= highest:
                highest, self.attr_max_val = attr.value, attr

    def __iter__(self):
        for name in self.attrs:
            yield name

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record inherited members; they must match the ones given at construction."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(name) for name in sorted(self._inherited)]
755
756
class EnumEntry(SpecEnumEntry):
    """Enum entry with C rendering details."""
    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # True when the value differs from its implicit successor value
        # (previous + 1, or 0 for the first entry), and always for flags.
        # Presumably used to decide whether to print an explicit value -- confirm
        implicit_value = prev.value + 1 if prev else 0
        self.value_change = self.value != implicit_value or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Set c_name: the set's value prefix plus the entry name, upper-cased."""
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
775
776
class EnumSet(SpecEnumSet):
    """Enum/flags set with C rendering details.

    enum_name is the C enum type ('enum <name>'), or None when the spec sets
    an empty 'enum-name'; user_type is the C type used for values (the enum
    type, or 'int' when there is no enum name).
    """
    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        if 'enum-name' in yaml:
            if yaml['enum-name']:
                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
            else:
                self.enum_name = None
        else:
            self.enum_name = 'enum ' + self.render_name

        self.user_type = self.enum_name if self.enum_name else 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (low, high) of the entry values.

        Raises:
            Exception: if the enum has no entries, or the values do not form
                a contiguous range.
        """
        values = [x.value for x in self.entries.values()]
        if not values:
            # min()/max() would fail with an unhelpful message here
            raise Exception("Can't get value range for an empty enum")

        low, high = min(values), max(values)
        if high - low + 1 != len(values):
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
810
811
class AttrSet(SpecAttrSet):
    """Attribute set with C naming; builds the Type subclass for each attribute."""
    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # A subset reuses the naming of the set it is a subset of
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Set c_name; it is '' when it would equal the family's c_name."""
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """Return the Type instance for spec attribute *elem*.

        Multi-attr attributes are wrapped in TypeMultiAttr around the
        per-occurrence type.
        """
        attr_type = elem['type']
        by_type = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'binary': TypeBinary,
            'bitfield32': TypeBitfield32,
            'nest': TypeNest,
            'nest-type-value': TypeNestTypeValue,
        }

        if attr_type in scalars:
            cls = TypeScalar
        elif attr_type in by_type:
            cls = by_type[attr_type]
        elif attr_type == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] != 'nest':
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            cls = TypeArrayNest
        else:
            raise Exception(f"No typed class for type {elem['type']}")

        t = cls(self.family, self, elem, value)
        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)
        return t
873
874
class Operation(SpecOperation):
    """Spec operation plus the C names the generator emits for it."""

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        # Lower-cased "<family>_<op>" stem used for generated function names.
        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # True only when both the 'do' and the 'dump' mode define a request.
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        # Flipped to True by mark_has_ntf().
        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Resolve the base op, then compute its command enum name."""
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that some message names this op as its notification."""
        self.has_ntf = True
900
901
class Family(SpecFamily):
    """Netlink spec family extended with what the C code generator needs.

    On top of SpecFamily this class does four things:
    - derives C-level names (family/version macro keys, UAPI header, op
      enum prefixes);
    - sorts attribute sets into message roots (``root_sets``) and
      nested-only sets (``pure_nested_structs``);
    - records whether each attribute is used in requests and/or replies;
    - collects the pre/post hook names.
    """
    def __init__(self, file_name, exclude_ops):
        # Added by resolve:
        # Each attribute below is assigned and then immediately deleted, so it
        # does not exist until resolve() assigns it. NOTE(review): 'consts' is
        # not set in this class's resolve(); presumably the SpecFamily base
        # resolve sets it - confirm in lib.
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        # Macro names for the family name / version strings. The spec can
        # override them via 'c-family-name' / 'c-version-name'.
        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        # Normalize a missing 'definitions' section to an empty list.
        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        # UAPI header path (default "linux/<ident_name>.h"). Its bare name is
        # the path with "linux/" and ".h" stripped, falling back to ident_name.
        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """Compute the generator-specific data after the base spec resolves.

        Raises:
            Exception: if the spec's protocol is not a generic netlink one.
        """
        self.resolve_up(super())

        # 'protocol' defaults to 'genetlink' when the spec omits it.
        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
            raise Exception("Codegen only supported for genetlink")

        self.c_name = c_lower(self.ident_name)
        # Command enum prefix comes from operations.name-prefix, or defaults
        # to "<name>-cmd-". Async ops use async-prefix, else the same prefix.
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        # Multicast groups default to an empty list.
        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] holds a 'set' that de-duplicates hook names
        # and a 'list' that keeps them in first-seen order. Both are filled
        # by _load_hooks().
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct, for attribute sets used only as nests
        # (each Struct carries request / reply / recursive flags)
        self.pure_nested_structs = dict()

        # Each step below consumes what the previous ones built, so the
        # order matters.
        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_hooks()

        # With 'global', one request policy is shared by all ops.
        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def _mark_notify(self):
        # Flag every op that some message names as its 'notify' target.
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Record which attributes each op-level attribute set uses.

        Request attributes and reply attributes are tracked separately.
        Event attributes count as replies. Results are merged across all
        messages that share a set.
        """
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Order pure_nested_structs so each struct follows its nests.

        A struct that references a not-yet-placed, non-recursive nest is
        moved to the end of the dict and retried later. The loop stops
        after len(...)**2 rounds without raising, even if it has not
        finished.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                    if self.pure_nested_structs[nested].recursive:
                        continue
                    if nested not in pns_key_seen:
                        # Dicts are sorted, this will make struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_sets(self):
        """Build a Struct for every attribute set reachable only as a nest.

        The walk is breadth-first from the root sets. Each nest Struct gets
        its inherited members: the 'type-value' names, or 'idx' for an
        indexed-array. Request/reply usage is then marked from the root
        sets and pushed down to children, child nests are collected, and
        self-referencing structs are flagged recursive.

        Raises:
            Exception: if a set is used both as a root and as a nest.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                if 'type-value' in spec:
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'indexed-array':
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        # Nests referenced directly from root messages take on that usage.
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        # (walking in reverse dependency order, so parents reach children
        # before the children pass the flags further down)
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                    struct.child_nests.add(child_name)
                    child = self.pure_nested_structs.get(child_name)
                    if child:
                        if not child.recursive:
                            struct.child_nests.update(child.child_nests)
                        child.request |= struct.request
                        child.reply |= struct.reply
                if attr_set in struct.child_nests:
                    struct.recursive = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Set .request / .reply on every attribute according to its use."""
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.request = True
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.reply = True

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.request = True
                if attr in rs_members['reply']:
                    spec.reply = True

    def _load_global_policy(self):
        """Collect the union of all ops' request attributes into one policy.

        Sets global_policy_set to the shared attribute set name.
        global_policy lists the used attributes in attribute-set order.

        Raises:
            Exception: if ops use different attribute sets.

        NOTE(review): if no op has an 'attribute-set', attr_set_name stays
        None and the attr_sets lookup below presumably raises KeyError.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Record each distinct pre/post hook name per mode, in first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1163
1164
class RenderInfo:
    """Rendering context for one operation/mode, or for a bare attribute set.

    Bundles the code writer, the family, the op and mode being rendered,
    and (when an op is given) request/reply Structs built from the op's
    attribute lists.

    Args:
        cw: CodeWriter that receives the output.
        family: the Family being rendered.
        ku_space: stored as-is on the instance.
        op: the Operation, or a falsy value (e.g. "" or None) when
            rendering an attribute set on its own. In that case
            ``attr_set`` must be given.
        op_mode: 'do', 'dump', 'notify', 'event' or a falsy value.
        attr_set: attribute set name; defaults to op['attribute-set'].
    """
    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        self.fixed_hdr = None
        if op and op.fixed_header:
            self.fixed_hdr = 'struct ' + c_lower(op.fixed_header)

        # 'do' and 'dump' response parsing is identical
        # (checked only for modes other than 'do'). The flag is cleared if
        # the op has a 'dump' but no 'do', if exactly one of them has a
        # 'reply', or if the two replies differ. The 'op and' guard matches
        # the other op checks here, so op=None does not evaluate
        # "'dump' in None".
        self.type_consistent = True
        if op and op_mode != 'do' and 'dump' in op:
            if 'do' in op:
                if ('reply' in op['do']) != ('reply' in op["dump"]):
                    self.type_consistent = False
                elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False
            else:
                self.type_consistent = False

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        # Type names derive from the op name, or from the attribute set
        # when there is no op. The latter may clash with a definition of
        # the same name in family.consts.
        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            if attr_set in family.consts:
                self.type_name_conflict = True

        self.cw = cw

        # Notifications reuse the 'do' message layout.
        self.struct = dict()
        if op_mode == 'notify':
            op_mode = 'do'
        for op_dir in ['request', 'reply']:
            if op:
                type_list = []
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1213
1214
class CodeWriter:
    """Line-oriented writer for generated C code.

    What it handles:
    - tracks the indentation level;
    - defers closing braces so that a following ``else`` is joined onto
      the '}';
    - indents the single body line of an unbraced if/for/while;
    - manages ``#ifdef CONFIG_*`` sections.

    Output goes to stdout or, when ``out_file`` is given, to a temporary
    file that close_out_file() copies to ``out_file``.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        # When False, an out_file whose contents already match is not rewritten.
        self._overwrite = overwrite

        self._nl = False            # emit a blank line before the next line
        self._block_end = False     # a closing '}' is waiting to be printed
        self._silent_block = False  # previous line was an unbraced conditional
        self._ind = 0
        self._ifdef_block = None
        self._out_file = None
        if out_file is None:
            self._out = os.sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """Copy buffered output to ``out_file`` and release the temp file.

        Does nothing when writing straight to stdout or when already closed.
        With ``overwrite=False``, an existing ``out_file`` whose contents
        already match is left untouched.
        """
        # getattr: __del__ may run on a half-constructed instance.
        if getattr(self, '_out_file', None) is None:
            return
        self._out.flush()
        # Avoid modifying the file if contents didn't change
        unchanged = (not self._overwrite and os.path.isfile(self._out_file) and
                     filecmp.cmp(self._out.name, self._out_file, shallow=False))
        if not unchanged:
            with open(self._out_file, 'w+') as out_file:
                self._out.seek(0)
                shutil.copyfileobj(self._out, out_file)
        # Release the temp file on every path, including the unchanged one.
        # Resetting _out_file makes later calls (e.g. from __del__) no-ops.
        self._out.close()
        self._out = os.sys.stdout
        self._out_file = None

    @classmethod
    def _is_cond(cls, line):
        # Require a word boundary after the keyword, so calls such as
        # "format_x(...)" or "iface(...)" are not taken for if/for/while.
        return re.match(r'(if|while|for)\b', line) is not None

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        Steps, in order:
        - flush a pending '}' first, merged with a leading ``else``;
        - de-indent labels (lines ending in ':');
        - indent the body line of an unbraced conditional;
        - put preprocessor lines at column 0.

        ``add_ind`` adds extra levels. An empty line is written without
        indentation.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write(('\t' * ind if line else '') + line + '\n')

    def nl(self):
        """Request a blank line before the next written line."""
        self._nl = True

    def block_start(self, line=''):
        """Write ``line`` followed by '{' and increase the indentation."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Decrease indentation and close the block.

        With no trailing text, the '}' is deferred until the next p() so
        an ``else`` can join it. With text (e.g. ';'), it is printed at
        once.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write ``doc`` as ' *'-prefixed comment lines wrapped before column 79.

        Continuation lines get two extra spaces when ``indent`` is set.
        """
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype, wrapping arguments past 80 columns.

        ``args`` defaults to ['void']. ``doc`` becomes a preceding comment
        and ``suffix`` (e.g. ';') is appended after the closing parenthesis.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        Accepts a single string or a list. The caller's list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a complete function definition.

        Local variables go inside the braces, ahead of ``body``. Emitting
        them before '{' would produce invalid C.
        """
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)

        self.block_start()
        self.write_func_lvar(local_vars=local_vars)
        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write aligned ``#define NAME value`` lines.

        int values are written verbatim and str values are quoted.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write aligned ``.name = value,`` initializer lines (none if empty)."""
        longest = max((len(x[0]) for x in members), default=0)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch to the ``#ifdef CONFIG_<config>`` section.

        A falsy config just closes the current section.
        """
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1410
1411
1412scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64', 'uint', 'sint'}
1413
1414direction_to_suffix = {
1415    'reply': '_rsp',
1416    'request': '_req',
1417    '': ''
1418}
1419
1420op_mode_to_wrapper = {
1421    'do': '',
1422    'dump': '_list',
1423    'notify': '_ntf',
1424    'event': '',
1425}
1426
1427_C_KW = {
1428    'auto',
1429    'bool',
1430    'break',
1431    'case',
1432    'char',
1433    'const',
1434    'continue',
1435    'default',
1436    'do',
1437    'double',
1438    'else',
1439    'enum',
1440    'extern',
1441    'float',
1442    'for',
1443    'goto',
1444    'if',
1445    'inline',
1446    'int',
1447    'long',
1448    'register',
1449    'return',
1450    'short',
1451    'signed',
1452    'sizeof',
1453    'static',
1454    'struct',
1455    'switch',
1456    'typedef',
1457    'union',
1458    'unsigned',
1459    'void',
1460    'volatile',
1461    'while'
1462}
1463
1464
def rdir(direction):
    """Return the opposite message direction.

    'request' and 'reply' swap; any other value is returned unchanged.
    """
    opposite = {'reply': 'request', 'request': 'reply'}
    return opposite.get(direction, direction)
1471
1472
def op_prefix(ri, direction, deref=False):
    """Return the C name stem ``<family>_<type><suffix>`` for ``ri``.

    The suffix depends on mode and direction:
    - no mode / 'do': the direction suffix (_req / _rsp);
    - other modes, request direction: '_req_dump';
    - other modes, reply, consistent do/dump replies: the direction suffix
      when ``deref`` is set, otherwise the mode wrapper (e.g. '_list');
    - other modes, reply, inconsistent replies: '_rsp_dump' when ``deref``
      is set, otherwise '_rsp_list'.
    """
    if not ri.op_mode or ri.op_mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp' + ('_dump' if deref else '_list')
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[ri.op_mode]

    return f"{ri.family.c_name}_{ri.type_name}{tail}"
1492
1493
def type_name(ri, direction, deref=False):
    """Return the ``struct ...`` type for ``ri`` (see op_prefix())."""
    return 'struct ' + op_prefix(ri, direction, deref=deref)
1496
1497
def print_prototype(ri, direction, terminate=True, doc=None):
    """Write the prototype of the ``<op>[_dump]()`` call for ``ri.op``.

    The function takes the socket, plus a request struct pointer when the
    mode has a request. It returns a reply struct pointer when the mode
    has a reply, otherwise int. ``terminate`` appends ';'.
    """
    mode_spec = ri.op[ri.op_mode]
    name = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')

    params = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        param_name = direction_to_suffix[direction][1:]
        params.append(f"{type_name(ri, direction)} *{param_name}")

    ret_type = f"{type_name(ri, rdir(direction))} *" if 'reply' in mode_spec else 'int'

    ri.cw.write_func_prot(ret_type, name, params, doc=doc, suffix=';' if terminate else '')
1514
1515
def print_req_prototype(ri):
    """Write the request prototype, using the op's 'doc' text as its comment."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1518
1519
def print_dump_prototype(ri):
    """Write the request prototype without a doc comment."""
    print_prototype(ri, direction="request")
1522
1523
def put_typol_fwd(cw, struct):
    """Write an extern declaration of ``<struct>_nest``."""
    declaration = 'extern const struct ynl_policy_nest ' + struct.render_name + '_nest;'
    cw.p(declaration)
1526
1527
def put_typol(cw, struct):
    """Write ``<struct>_policy[]`` and the ``<struct>_nest`` wrapper around it."""
    max_attr = struct.attr_set.max_name
    name = struct.render_name

    # One policy entry per member, written by the member itself.
    cw.block_start(line=f'const struct ynl_policy_attr {name}_policy[{max_attr} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    # Nest descriptor pointing at the table above.
    cw.block_start(line=f'const struct ynl_policy_nest {name}_nest =')
    cw.p(f'.max_attr = {max_attr},')
    cw.p(f'.table = {name}_policy,')
    cw.block_end(line=';')
    cw.nl()
1543
1544
1545def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
1546    args = [f'int {arg_name}']
1547    if enum:
1548        args = [enum.user_type + ' ' + arg_name]
1549    cw.write_func_prot('const char *', f'{render_name}_str', args)
1550    cw.block_start()
1551    if enum and enum.type == 'flags':
1552        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
1553    cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))')
1554    cw.p('return NULL;')
1555    cw.p(f'return {map_name}[{arg_name}];')
1556    cw.block_end()
1557    cw.nl()
1558
1559
def put_op_name_fwd(family, cw):
    """Write the prototype of ``<family>_op_str()``."""
    fn_name = family.c_name + '_op_str'
    cw.write_func_prot('const char *', fn_name, ['int op'], suffix=';')
1562
1563
def put_op_name(family, cw):
    """Write the reply-value -> op-name table and its ``<family>_op_str()``."""
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
1583
1584
def put_enum_to_str_fwd(family, cw, enum):
    """Write the prototype of ``<enum>_str()``. ``family`` is unused."""
    cw.write_func_prot('const char *', enum.render_name + '_str',
                       [enum.user_type + ' value'], suffix=';')
1588
1589
def put_enum_to_str(family, cw, enum):
    """Write ``<enum>_strmap`` (value -> name) and the ``<enum>_str()`` lookup."""
    map_name = enum.render_name + '_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    rows = [f'[{e.value}] = "{e.name}",' for e in enum.entries.values()]
    for row in rows:
        cw.p(row)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1599
1600
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Write the prototype of ``<struct>_put(nlh, attr_type, obj)``."""
    params = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        struct.ptr_name + 'obj',
    ]
    ri.cw.write_func_prot('int', struct.render_name + '_put', params, suffix=suffix)
1608
1609
def put_req_nested(ri, struct):
    """Write ``<struct>_put()``: open a nest, put each member, close the nest."""
    writer = ri.cw
    put_req_nested_prototype(ri, struct, suffix='')
    writer.block_start()
    writer.write_func_lvar('struct nlattr *nest;')

    writer.p("nest = ynl_attr_nest_start(nlh, attr_type);")
    for _, member in struct.member_list():
        member.attr_put(ri, "obj")
    writer.p("ynl_attr_nest_end(nlh, nest);")

    writer.nl()
    writer.p('return 0;')
    writer.block_end()
    writer.nl()
1626
1627
def _multi_parse(ri, struct, init_lines, local_vars):
    """Write the body of a parse function for ``struct``, braces included.

    Used for nested attribute sets (``struct.nested``) and for top-level
    messages. The caller has already written the prototype.

    Args:
        ri: RenderInfo for the code being generated.
        struct: Struct describing the members to parse.
        init_lines: C statements written right after the locals. Extended
            in place with 'parg.ys = yarg->ys;' when a member has nested
            attributes.
        local_vars: C local declarations. Extended in place.

    Raises:
        Exception: for an indexed-array sub-type other than 'nest', or a
            multi-attr member that is neither a nest nor a scalar.
    """
    # Nested parsers walk the nest payload. Message parsers walk the
    # attributes after the family header (and copy a fixed header if any).
    if struct.nested:
        iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        if ri.fixed_hdr:
            local_vars += ['void *hdr;']
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"

    # Members parsed into arrays: indexed-array nests and multi-attrs. Their
    # element counts go into n_<member> locals during the main loop
    # (presumably incremented by each member's attr_get() code). The arrays
    # are allocated and filled in a second pass further down.
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] == 'nest':
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    # Inherited members arrive as extra parse() arguments.
    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if ri.fixed_hdr:
        ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
    # Refuse to parse into a dst whose array members are already set.
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    # Main loop: each member writes its own handling of 'type'.
    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for indexed-array nests: allocate the array, then parse
    # each entry of the saved nest, passing its attribute type as the
    # entry index. NOTE(review): the generated calloc() result is not
    # checked for NULL.
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
        ri.cw.p('return YNL_PARSE_CB_ERROR;')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attrs: walk all attributes again and store each
    # occurrence of the member's type. Nests are parsed recursively and
    # scalars are read directly.
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return 0. Message parsers return the callback code.
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1739
1740
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Emit the C prototype of the parser for nested struct *struct*.

    The generated function takes the parse argument and the nest attribute,
    followed by one ``__u32`` parameter per inherited attribute name.
    *suffix* is written after the closing parenthesis (';' for a bare
    declaration, '' when a body follows).
    """
    inherited_args = ['__u32 ' + inherited for inherited in struct.inherited]
    params = ['struct ynl_parse_arg *yarg',
              'const struct nlattr *nested',
              *inherited_args]
    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', params,
                          suffix=suffix)
1749
1750
def parse_rsp_nested(ri, struct):
    """Emit the complete parser function for nested struct *struct*.

    The prototype is written without a terminator, then _multi_parse()
    emits the body. The generated locals are the attribute cursor and the
    destination pointer taken from yarg->data; no extra init lines are
    needed.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    _multi_parse(ri, struct, [],
                 ['const struct nlattr *attr;',
                  f'{struct.ptr_name}dst = yarg->data;'])
1759
1760
def parse_rsp_msg(ri, deref=False):
    """Emit the <op>_parse() callback that parses a reply message.

    Nothing is emitted when the current op mode has no 'reply', unless the
    op mode is 'event'. A reply struct without members gets a stub body
    that only returns YNL_PARSE_CB_OK.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    dst_decl = f'{type_name(ri, "reply", deref=deref)} *dst;'
    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse',
                          ['const struct nlmsghdr *nlh',
                           'struct ynl_parse_arg *yarg'])

    reply_struct = ri.struct["reply"]
    if reply_struct.member_list():
        _multi_parse(ri, reply_struct, ['dst = yarg->data;'],
                     [dst_decl, 'const struct nlattr *attr;'])
        return

    # Empty reply: there is nothing to parse.
    ri.cw.block_start()
    ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1782
1783
def print_req(ri):
    """Emit the user-space "do" request function for the current operation.

    The generated C function starts a request message, copies in
    req->_hdr when the op has a fixed header, and writes every request
    attribute. It then runs ynl_exec().
      * With a reply: it allocates rsp, parses into it and returns it;
        on error it frees rsp and returns NULL.
      * Without a reply: it returns 0 on success and -1 on error.
    """
    direction = "request"
    has_reply = 'reply' in ri.op[ri.op_mode]

    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']
    if has_reply:
        ret_ok, ret_err = 'rsp', 'NULL'
        local_vars.append(f'{type_name(ri, rdir(direction))} *rsp;')
    else:
        ret_ok, ret_err = '0', '-1'
    if ri.fixed_hdr:
        local_vars.extend(['size_t hdr_len;', 'void *hdr;'])

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()

    if ri.fixed_hdr:
        # Append req->_hdr to the message as its extra (fixed) header.
        for stmt in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            ri.cw.p(stmt)
        ri.cw.nl()

    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    if has_reply:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        ri.cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        ri.cw.nl()
    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto err_free;' if has_reply else 'return -1;')
    ri.cw.nl()

    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()

    if has_reply:
        ri.cw.p('err_free:')
        ri.cw.p(call_free(ri, rdir(direction), 'rsp'))
        ri.cw.p(f"return {ret_err};")

    ri.cw.block_end()
1848
1849
def print_dump(ri):
    """Emit the user-space dump function for the current operation.

    The generated C function sets up a ynl_dump_state and starts a dump
    message. It copies in req->_hdr when the op has a fixed header, and
    writes the request attributes when the mode has a 'request'. It then
    runs ynl_exec_dump() and returns yds.first. On error it frees that
    list and returns NULL.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()

    lvars = ['struct ynl_dump_state yds = {};',
             'struct nlmsghdr *nlh;',
             'int err;']
    if ri.fixed_hdr:
        lvars.extend(['size_t hdr_len;', 'void *hdr;'])
    ri.cw.write_func_lvar(lvars)

    # Dump state: reply policy, per-entry allocation size, parse callback
    # and the command expected in replies.
    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    state_setup = ['yds.yarg.ys = ys;',
                   f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;",
                   "yds.yarg.data = NULL;",
                   f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});",
                   f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;",
                   f'yds.rsp_cmd = {rsp_cmd};']
    for stmt in state_setup:
        ri.cw.p(stmt)
    ri.cw.nl()
    ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if ri.fixed_hdr:
        for stmt in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            ri.cw.p(stmt)
        ri.cw.nl()

    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    for stmt in ('err = ynl_exec_dump(ys, nlh, &yds);',
                 'if (err < 0)',
                 'goto free_list;'):
        ri.cw.p(stmt)
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
1900
1901
def call_free(ri, direction, var):
    """Return the C statement that frees *var* using the op's <prefix>_free()."""
    free_fn = f"{op_prefix(ri, direction)}_free"
    return f"{free_fn}({var});"
1904
1905
def free_arg_name(direction):
    """Return the parameter name used by generated *_free() helpers.

    For a non-empty *direction* this is its direction_to_suffix entry with
    the first character dropped; for an empty direction it is 'obj'.
    """
    return direction_to_suffix[direction][1:] if direction else 'obj'
1910
1911
def print_alloc_wrapper(ri, direction):
    """Emit a static inline <prefix>_alloc(void) that calloc()s one struct."""
    struct_name = op_prefix(ri, direction)
    ri.cw.write_func_prot(f'static inline struct {struct_name} *',
                          f"{struct_name}_alloc", ["void"])
    ri.cw.block_start()
    ri.cw.p(f'return calloc(1, sizeof(struct {struct_name}));')
    ri.cw.block_end()
1918
1919
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of the <prefix>_free() helper for *direction*.

    When ri.type_name_conflict is set, the struct tag in the parameter type
    gets a trailing underscore; the function name does not.
    """
    fn_base = op_prefix(ri, direction)
    tag = f"{fn_base}_" if ri.type_name_conflict else fn_base
    param = f"struct {tag} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{fn_base}_free", [param], suffix=suffix)
1927
1928
def _print_type(ri, direction, struct):
    """Emit the C definition of *struct* for the given *direction*.

    Struct tag: <family>_<type_name><direction suffix>, plus a trailing
    '_' for direction-less types when ri.type_name_conflict is set, plus
    '_dump' in dump mode.

    Members, in order:
      * the fixed header, if any;
      * a `_present` sub-struct with each attribute's 'len' and then 'bit'
        presence members (omitted when there are none);
      * one __u32 per inherited name;
      * the attribute members.
    """
    tag = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        tag += '_'
    if ri.op_mode == 'dump':
        tag += '_dump'

    cw = ri.cw
    cw.block_start(line=f"struct {ri.family.c_name}{tag}")

    if ri.fixed_hdr:
        cw.p(ri.fixed_hdr + ' _hdr;')
        cw.nl()

    presence = []
    for _, attr in struct.member_list():
        for kind in ('len', 'bit'):
            member = attr.presence_member(ri.ku_space, kind)
            if member:
                presence.append(member)
    if presence:
        cw.block_start(line="struct")
        for member in presence:
            cw.p(member)
        cw.block_end(line='_present;')
        cw.nl()

    for inherited in struct.inherited:
        cw.p(f"__u32 {inherited};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
1964
1965
def print_type(ri, direction):
    """Emit the struct definition of the op's *direction* struct."""
    op_struct = ri.struct[direction]
    _print_type(ri, direction, op_struct)
1968
1969
def print_type_full(ri, struct):
    """Emit a direction-less struct definition for *struct*."""
    _print_type(ri, direction="", struct=struct)
1972
1973
def print_type_helpers(ri, direction, deref=False):
    """Emit the free() prototype and, for user-space requests, attr setters."""
    print_free_prototype(ri, direction)
    ri.cw.nl()

    needs_setters = ri.ku_space == 'user' and direction == 'request'
    if needs_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1982
1983
def print_req_type_helpers(ri):
    """Emit the alloc wrapper and helpers, unless the request has no attrs."""
    if not ri.struct["request"].attr_list:
        return
    print_alloc_wrapper(ri, "request")
    print_type_helpers(ri, "request")
1989
1990
def print_rsp_type_helpers(ri):
    """Emit the reply helpers when the current op mode has a 'reply'."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1995
1996
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of the tb[]-based <op>_rsp/_req_parse() function.

    'reply' selects the _rsp variant; any other direction selects _req.
    """
    kind = "_rsp" if direction == "reply" else "_req"
    base = f"{ri.op.render_name}{kind}"
    ri.cw.write_func_prot('void', f"{base}_parse",
                          ['const struct nlattr **tb', f"struct {base} *req"],
                          suffix=';' if terminate else '')
2005
2006
def print_req_type(ri):
    """Emit the request struct definition unless it has no attributes."""
    if ri.struct["request"].attr_list:
        print_type(ri, "request")
2011
2012
def print_req_free(ri):
    """Emit the request free() helper when the op mode has a 'request'."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2017
2018
def print_rsp_type(ri):
    """Emit the reply struct for events and for do/dump ops with a reply."""
    mode = ri.op_mode
    if mode in ('do', 'dump'):
        wanted = 'reply' in ri.op[mode]
    else:
        wanted = mode == 'event'
    if wanted:
        print_type(ri, 'reply')
2027
2028
def print_wrapped_type(ri):
    """Emit the wrapper struct around a reply object, then its free() prototype.

    Dump wrappers get a `next` pointer. Notification/event wrappers get
    family, cmd, a ynl_ntf_base_type `next` and a `free` callback. Every
    wrapper embeds the reply as an 8-byte aligned `obj` member.
    """
    wrapper = type_name(ri, 'reply')
    if ri.op_mode == 'dump':
        leading = [f"{wrapper} *next;"]
    elif ri.op_mode in ('notify', 'event'):
        leading = ['__u16 family;',
                   '__u8 cmd;',
                   'struct ynl_ntf_base_type *next;',
                   f"void (*free)({wrapper} *ntf);"]
    else:
        leading = []

    ri.cw.block_start(line=wrapper)
    for member in leading:
        ri.cw.p(member)
    ri.cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    ri.cw.block_end(line=';')
    ri.cw.nl()
    print_free_prototype(ri, 'reply')
    ri.cw.nl()
2043
2044
2045def _free_type_members_iter(ri, struct):
2046    for _, attr in struct.member_list():
2047        if attr.free_needs_iter():
2048            ri.cw.p('unsigned int i;')
2049            ri.cw.nl()
2050            break
2051
2052
2053def _free_type_members(ri, var, struct, ref=''):
2054    for _, attr in struct.member_list():
2055        attr.free(ri, var, ref)
2056
2057
def _free_type(ri, direction, struct):
    """Emit the definition of the <prefix>_free() helper for *struct*.

    The body frees every member. For a non-empty *direction* it also
    free()s the object itself; direction-less types only free members.
    """
    arg = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        ri.cw.p(f'free({arg});')
    ri.cw.block_end()
    ri.cw.nl()
2069
2070
def free_rsp_nested_prototype(ri):
    """Emit the free() prototype for a direction-less (nested) struct."""
    print_free_prototype(ri, direction="")
2073
2074
def free_rsp_nested(ri, struct):
    """Emit the free() definition for nested struct *struct*."""
    _free_type(ri, direction="", struct=struct)
2077
2078
def print_rsp_free(ri):
    """Emit the reply free() helper when the op mode has a 'reply'."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2083
2084
def print_dump_type_free(ri):
    """Emit the free() helper for a dump reply list.

    The generated loop walks `next` pointers until YNL_LIST_END. For each
    node it frees the members of the embedded `obj` and then the node.
    """
    cw = ri.cw
    reply = ri.struct['reply']
    node_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{node_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply)
    for stmt in ('rsp = next;', 'next = rsp->next;'):
        cw.p(stmt)
    cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
2103
2104
def print_ntf_type_free(ri):
    """Emit the free() helper for a notification: obj members, then rsp."""
    cw = ri.cw
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
2113
2114
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit an nla_policy array declaration or the opening of its definition.

    terminate=True writes an `extern` declaration, which is skipped when an
    op is given and policy_should_be_static() holds. terminate=False writes
    the definition opener ('= {'), made `static` in that same situation.

    Naming: with an op, the array is named after the op (plus the op mode
    for dual-policy ops); otherwise it is named after the struct.
    """
    static_policy = bool(ri) and policy_should_be_static(struct.family)
    if terminate and static_policy:
        return

    if terminate:
        prefix, suffix = 'extern ', ';'
    elif static_policy:
        prefix, suffix = 'static ', ' = {'
    else:
        prefix, suffix = '', ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    elif ri.op.dual_policy:
        name = ri.op.render_name + '_' + ri.op_mode
    else:
        name = ri.op.render_name
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2137
2138
def print_req_policy(cw, struct, ri=None):
    """Emit a complete nla_policy array for *struct*, one entry per member.

    When an op is being rendered, its 'config-cond' value is passed to
    cw.ifdef_block() before the array; the ifdef block is reset afterwards.
    """
    has_op = ri and ri.op
    if has_op:
        cw.ifdef_block(ri.op.get('config-cond', None))
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
    cw.ifdef_block(None)
    cw.nl()
2148
2149
def kernel_can_gen_family_struct(family):
    """Tell whether the genl_family struct is generated (plain 'genetlink' only)."""
    proto = family.proto
    return proto == 'genetlink'
2152
2153
def policy_should_be_static(family):
    """Per-op policies are static for split-policy families or when the
    genl_family struct itself is generated."""
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2156
2157
def print_kernel_policy_ranges(family, cw):
    """Emit netlink_range_validation structs for ranged request attributes.

    Only request attributes with a 'full-range' check in non-subset
    attribute sets are covered. A single comment header precedes the
    first struct.
    """
    header_written = False
    for attr_set in family.attr_sets.values():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if not header_written:
                cw.p('/* Integer value ranges */')
                header_written = True

            unsigned = attr.type[0] == 'u'
            sign = '' if unsigned else '_signed'
            literal_suffix = 'ULL' if unsigned else 'LL'
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            bounds = [(bound, attr.get_limit_str(bound, suffix=literal_suffix))
                      for bound in ('min', 'max') if bound in attr.checks]
            cw.write_struct_init(bounds)
            cw.block_end(line=';')
            cw.nl()
2185
2186
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the ops-table declaration/opener and the kernel handler prototypes.

    terminate=False: open the ops array definition (the caller writes the
    entries and closes it) and return. The array is `static const` and
    unsized when the genl_family struct is generated here.

    terminate=True: emit an `extern` declaration of the array, only when the
    table is exported (genl_family struct not generated). Then emit
    prototypes for the pre/post hooks and for each non-async op's doit and
    dumpit handler.
    """
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.ident_name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        # Exported tables are declared with an explicit size. It must match
        # the number of entries print_kernel_op_table() writes, which skips
        # async (notify/event) ops. Otherwise the array gets zeroed trailing
        # entries.
        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            # One entry per do/dump mode of every non-async op.
            cnt = sum(1 for op in family.ops.values() if not op.is_async
                      for op_mode in ('do', 'dump') if op_mode in op)
        else:
            # One entry per non-async op.
            cnt = sum(1 for op in family.ops.values() if not op.is_async)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    cw.nl()
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.ident_name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.ident_name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
2252
2253
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side part of the kernel ops table (declarations only)."""
    declarations_only = True
    print_kernel_op_table_fwd(family, cw, terminate=declarations_only)
2256
2257
def print_kernel_op_table(family, cw):
    """Emit the kernel genl ops array for *family*.

    The array is opened by print_kernel_op_table_fwd(terminate=False) and
    closed here.
      * 'global' / 'per-op' policies: one entry per non-async op.
      * 'split' policy: one entry per do/dump mode of each non-async op.
    Each entry is preceded by cw.ifdef_block() with the op's 'config-cond'.
    """
    def genl_flags(prefix, names):
        # Upper-case prefix+name for every name and OR them together.
        return ' | '.join(c_upper(prefix + n) for n in names)

    print_kernel_op_table_fwd(family, cw, terminate=False)

    policy = family.kernel_policy
    if policy in ('global', 'per-op'):
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                genl_flags('genl-dont-validate-', op['dont-validate'])))
            members.extend((mode + 'it', c_lower(f"{family.ident_name}-nl-{op_name}-{mode}it"))
                           for mode in ('do', 'dump') if mode in op)
            if policy == 'per-op':
                # NOTE(review): assumes every per-op op has a 'do' with a
                # 'request'; a dump-only op would raise KeyError here.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])
                members.append(('policy', c_lower(f"{family.ident_name}-{op_name}-nl-policy")))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', genl_flags('genl-', op['flags'])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif policy == 'split':
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}
        # dont-validate values that do not apply to a given mode.
        not_for_mode = {'do': ('dump', 'dump-strict'), 'dump': ('strict',)}

        for op_name, op in family.ops.items():
            for op_mode in ('do', 'dump'):
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    dont_validate = [x for x in op['dont-validate']
                                     if x not in not_for_mode[op_mode]]
                    if dont_validate:
                        members.append(('validate',
                                        genl_flags('genl-dont-validate-', dont_validate)))

                mode_spec = op[op_mode]
                handler = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in mode_spec:
                    members.append((cb_names[op_mode]['pre'], c_lower(mode_spec['pre'])))
                members.append((op_mode + 'it', handler))
                if 'post' in mode_spec:
                    members.append((cb_names[op_mode]['post'], c_lower(mode_spec['post'])))
                if 'request' in mode_spec:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=mode_spec['request']['attributes'])
                    if op.dual_policy:
                        policy_name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        policy_name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', policy_name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', genl_flags('genl-', flags)))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2336
2337
def print_kernel_mcgrp_hdr(family, cw):
    """Emit the enum of multicast group IDs; nothing if there are no groups."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{grp['name']},"))
    cw.block_end(';')
    cw.nl()
2348
2349
def print_kernel_mcgrp_src(family, cw):
    """Emit the genl_multicast_group array, indexed by the group ID enum."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        enum_id = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{enum_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2361
2362
def print_kernel_family_struct_hdr(family, cw):
    """Emit the extern genl_family declaration and the optional sock-priv
    init/destroy prototypes (only for generated family structs)."""
    if not kernel_can_gen_family_struct(family):
        return

    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
    cw.nl()
    if 'sock-priv' not in family.kernel_family:
        return

    priv_type = family.kernel_family["sock-priv"]
    for action in ('init', 'destroy'):
        cw.p(f'void {family.c_name}_nl_sock_priv_{action}({priv_type} *priv);')
    cw.nl()
2373
2374
def print_kernel_family_struct_src(family, cw):
    """Emit the genl_family definition for families whose struct is generated.

    The symbol is named with family.c_name, the same name used by the extern
    declaration from print_kernel_family_struct_hdr(), so declaration and
    definition always agree.
    """
    if not kernel_can_gen_family_struct(family):
        return

    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    if 'sock-priv' in family.kernel_family:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        # Force cast here, actual helpers take pointer to the real type.
        cw.p(f'.sock_priv_init\t= (void *){family.c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = (void *){family.c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
2400
2401
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a uAPI enum block and choose its tag.

    Tag selection, in order:
      1. If *enum_name* is a key of *obj*, its value is the tag; a falsy
         value gives an anonymous enum.
      2. Otherwise, if *ckey* is set and present in *obj*, the tag is
         <family c_name>_<obj[ckey]>.
      3. Otherwise the enum is anonymous.
    """
    if enum_name in obj:
        explicit = obj[enum_name]
        start_line = f'enum {c_lower(explicit)}' if explicit else 'enum'
    elif ckey and ckey in obj:
        start_line = f'enum {family.c_name}_{c_lower(obj[ckey])}'
    else:
        start_line = 'enum'
    cw.block_start(line=start_line)
2410
2411
def render_uapi(family, cw):
    """Emit the uAPI header for *family*.

    Contents, in order:
      * the include guard;
      * the family name and version defines;
      * the spec 'definitions': consts as #defines, enums/flags as C enums
        with optional kdoc;
      * one enum per non-subset attribute set;
      * the command enum, plus a separate notification enum when the
        operations have an 'async-prefix';
      * the multicast group name defines.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive 'const' definitions are collected and written as one group
    # of #defines. The pending group is flushed before any other definition.
    defines = []
    for const in family['definitions']:
        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                doc = ''
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    max_name = c_upper(name_pfx + 'max')
                    cw.p('__' + max_name + ',')
                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # A constant may override its macro name with its own
            # 'c-define-name'. The key must be read from the constant itself:
            # a family-level lookup would give every constant the same name.
            defines.append([c_upper(const.get('c-define-name',
                                              f"{family.ident_name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    val = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2555
2556
def _render_user_ntf_entry(ri, op):
    """Emit one `[<cmd>] = { ... },` entry of the ynl_ntf_info table."""
    ri.cw.block_start(line=f"[{op.enum_name}] = ")
    fields = (f".alloc_sz\t= sizeof({type_name(ri, 'event')}),",
              f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,",
              f".policy\t\t= &{ri.struct['reply'].render_name}_nest,",
              f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
    for field in fields:
        ri.cw.p(field)
    ri.cw.block_end(line=',')
2564
2565
def render_user_family(family, cw, prototype):
    """Emit the user-space `ynl_<family>_family` descriptor.

    With *prototype* set, only the extern declaration is emitted. Otherwise,
    for families with notifications, a static ynl_ntf_info table (one
    designated entry per notification/event command) is emitted first and
    referenced from the descriptor.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    # The table name must be a valid C identifier. The raw spec name may
    # contain dashes, so use the C-ified family name.
    ntf_table = f"{family.c_name}_ntf_info"

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {ntf_table}[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                op = family.ops[ntf_op['notify']]
                ri = RenderInfo(cw, family, "user", op, "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        for op in family.ops.values():
            if 'event' not in op:
                continue
            ri = RenderInfo(cw, family, "user", op, "event")
            _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    # .name must be the genetlink family name itself (the string the uAPI
    # header defines from family["name"]), not its C-ified form.
    family_name = family['name']
    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family_name}",')
    if family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {ntf_table},")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({ntf_table}),")
    cw.block_end(line=';')
2601
2602
def family_contains_bitfield32(family):
    """Return True if any attribute of *family* has type "bitfield32".

    Attribute sets whose ``subset_of`` is set are not inspected
    (presumably their attributes are covered by the set they subset —
    confirm).  The scan stops at the first match.
    """
    return any(
        attr.type == "bitfield32"
        for _, attr_set in family.attr_sets.items()
        if not attr_set.subset_of
        for _, attr in attr_set.items()
    )
2611
2612
def find_kernel_root(full_path):
    """Locate the kernel source tree that contains *full_path*.

    Walks up from *full_path* one directory at a time until it reaches a
    directory holding a ``MAINTAINERS`` file.

    Returns a ``(root, sub_path)`` tuple.  *root* is that directory.
    *sub_path* is *full_path* relative to it, e.g.
    ``Documentation/netlink/specs/foo.yaml``.  For a relative *full_path*
    the last directory checked is the current working directory, in which
    case *root* is ``''``.

    Raises FileNotFoundError if no ancestor directory contains
    ``MAINTAINERS``.
    """
    start = full_path
    sub_path = ''
    while True:
        # os.path.join(x, '') appends a trailing separator; it is stripped
        # from the result below.
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if parent == full_path:
            # dirname() is a fixed point at '/' (or '' for a relative path).
            # Every ancestor has been checked, so continuing would loop forever.
            raise FileNotFoundError(
                f"Cannot find kernel root (no MAINTAINERS file) above {start}")
        full_path = parent
        if os.path.exists(os.path.join(full_path, "MAINTAINERS")):
            return full_path, sub_path[:-1]
2621
2622
def main():
    """Command-line entry point of the YNL C code generator.

    Parses the command line and loads the spec given by ``--spec``.  Output
    is written through a CodeWriter to the file given by ``-o`` (what
    CodeWriter does without ``-o`` is defined outside this function).

    ``--mode`` selects what is rendered:

    * ``uapi`` renders the uAPI header.
    * ``kernel`` renders kernel policies, op tables, mcgrps and the family
      struct.
    * ``user`` renders the user-space ynl code.

    ``--header``/``--source`` pick which half is generated; one of them is
    required.

    Exits with status 1 when the spec fails to parse as YAML, its license
    is not the required dual license, or its msg-id model is not supported
    for the requested mode.
    """
    import sys

    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    # --header and --source share one destination; None means neither given.
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
    except yaml.YAMLError as exc:
        print(exc)
        sys.exit(1)

    # Generated code must carry the dual license the spec itself declares.
    required_license = '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)'
    if parsed.license != required_license:
        print('Spec license:', parsed.license)
        print(f'License must be: {required_license}')
        sys.exit(1)

    supported_models = ['unified']
    if args.mode in ['user', 'kernel']:
        supported_models += ['directional']
    if parsed.msg_id_model not in supported_models:
        print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation')
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        # Record the extra arguments so the output can be regenerated.
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    # A generated source includes its own header, named after the output
    # file ("foo.c" -> "foo.h").  Without -o there is no name to derive it
    # from, so the self-include is skipped rather than crashing on None.
    hdr_file = None
    if args.out_file:
        hdr_file = os.path.basename(args.out_file[:-2]) + ".h"

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if hdr_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            if hdr_file:
                cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = [parsed.uapi_header]
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        _main_render_kernel(args, parsed, cw)

    if args.mode == "user":
        if args.header:
            _main_render_user_header(args, parsed, cw)
        else:
            _main_render_user_source(args, parsed, cw)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')


def _main_render_kernel(args, parsed, cw):
    """Print kernel code: policy/op-table declarations (header) or definitions (source)."""
    if args.header:
        for _, struct in sorted(parsed.pure_nested_structs.items()):
            if struct.request:
                cw.p('/* Common nested types */')
                break
        for _, struct in sorted(parsed.pure_nested_structs.items()):
            if struct.request:
                print_req_policy_fwd(cw, struct)
        cw.nl()

        if parsed.kernel_policy == 'global':
            cw.p(f"/* Global operation policy for {parsed.name} */")

            struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
            print_req_policy_fwd(cw, struct)
            cw.nl()

        if parsed.kernel_policy in {'per-op', 'split'}:
            # Only 'do' request policies are forward-declared in the header.
            for _, op in parsed.ops.items():
                if 'do' in op and 'event' not in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                    cw.nl()

        print_kernel_op_table_hdr(parsed, cw)
        print_kernel_mcgrp_hdr(parsed, cw)
        print_kernel_family_struct_hdr(parsed, cw)
        return

    print_kernel_policy_ranges(parsed, cw)

    for _, struct in sorted(parsed.pure_nested_structs.items()):
        if struct.request:
            cw.p('/* Common nested types */')
            break
    for _, struct in sorted(parsed.pure_nested_structs.items()):
        if struct.request:
            print_req_policy(cw, struct)
    cw.nl()

    if parsed.kernel_policy == 'global':
        cw.p(f"/* Global operation policy for {parsed.name} */")

        struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
        print_req_policy(cw, struct)
        cw.nl()

    if parsed.kernel_policy in {'per-op', 'split'}:
        for _, op in parsed.ops.items():
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    cw.p(f"/* {op.enum_name} - {op_mode} */")
                    ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                    print_req_policy(cw, ri.struct['request'], ri=ri)
                    cw.nl()

    print_kernel_op_table(parsed, cw)
    print_kernel_mcgrp_src(parsed, cw)
    print_kernel_family_struct_src(parsed, cw)


def _main_render_user_header(args, parsed, cw):
    """Print the user-space header: enum helpers, types and prototypes."""
    cw.p('/* Enums */')
    put_op_name_fwd(parsed, cw)

    for _, const in parsed.consts.items():
        if isinstance(const, EnumSet):
            put_enum_to_str_fwd(parsed, cw, const)
    cw.nl()

    cw.p('/* Common nested types */')
    for attr_set, struct in parsed.pure_nested_structs.items():
        ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
        print_type_full(ri, struct)

    for _, op in parsed.ops.items():
        cw.p(f"/* ============== {op.enum_name} ============== */")

        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, args.mode, op, "do")
            print_req_type(ri)
            print_req_type_helpers(ri)
            cw.nl()
            print_rsp_type(ri)
            print_rsp_type_helpers(ri)
            cw.nl()
            print_req_prototype(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
            print_req_type(ri)
            print_req_type_helpers(ri)
            if not ri.type_consistent:
                print_rsp_type(ri)
            print_wrapped_type(ri)
            print_dump_prototype(ri)
            cw.nl()

        if op.has_ntf:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
            if not ri.type_consistent:
                raise Exception(f'Only notifications with consistent types supported ({op.name})')
            print_wrapped_type(ri)

    for _, op in parsed.ntfs.items():
        if 'event' in op:
            ri = RenderInfo(cw, parsed, args.mode, op, 'event')
            cw.p(f"/* {op.enum_name} - event */")
            print_rsp_type(ri)
            cw.nl()
            print_wrapped_type(ri)
    cw.nl()


def _main_render_user_source(args, parsed, cw):
    """Print the user-space source: enum helpers, policies, parsers, requests, family."""
    cw.p('/* Enums */')
    put_op_name(parsed, cw)

    for _, const in parsed.consts.items():
        if isinstance(const, EnumSet):
            put_enum_to_str(parsed, cw, const)
    cw.nl()

    # Recursive nests need their policies and helpers declared before use.
    has_recursive_nests = False
    cw.p('/* Policies */')
    for struct in parsed.pure_nested_structs.values():
        if struct.recursive:
            put_typol_fwd(cw, struct)
            has_recursive_nests = True
    if has_recursive_nests:
        cw.nl()
    for name in parsed.pure_nested_structs:
        struct = Struct(parsed, name)
        put_typol(cw, struct)
    for name in parsed.root_sets:
        struct = Struct(parsed, name)
        put_typol(cw, struct)

    cw.p('/* Common nested types */')
    if has_recursive_nests:
        for attr_set, struct in parsed.pure_nested_structs.items():
            ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
            free_rsp_nested_prototype(ri)
            if struct.request:
                put_req_nested_prototype(ri, struct)
            if struct.reply:
                parse_rsp_nested_prototype(ri, struct)
        cw.nl()
    for attr_set, struct in parsed.pure_nested_structs.items():
        ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

        free_rsp_nested(ri, struct)
        if struct.request:
            put_req_nested(ri, struct)
        if struct.reply:
            parse_rsp_nested(ri, struct)

    for _, op in parsed.ops.items():
        cw.p(f"/* ============== {op.enum_name} ============== */")
        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, args.mode, op, "do")
            print_req_free(ri)
            print_rsp_free(ri)
            parse_rsp_msg(ri)
            print_req(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, args.mode, op, "dump")
            if not ri.type_consistent:
                parse_rsp_msg(ri, deref=True)
            print_req_free(ri)
            print_dump_type_free(ri)
            print_dump(ri)
            cw.nl()

        if op.has_ntf:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
            if not ri.type_consistent:
                raise Exception(f'Only notifications with consistent types supported ({op.name})')
            print_ntf_type_free(ri)

    for _, op in parsed.ntfs.items():
        if 'event' in op:
            cw.p(f"/* {op.enum_name} - event */")

            ri = RenderInfo(cw, parsed, args.mode, op, "do")
            parse_rsp_msg(ri)

            ri = RenderInfo(cw, parsed, args.mode, op, "event")
            print_ntf_type_free(ri)
    cw.nl()
    render_user_family(parsed, cw, False)


if __name__ == "__main__":
    main()
2933