xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision 1a9239bb4253f9076b5b4b2a1a4e8d7defd77a95)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17
18
def c_upper(name):
    """Return *name* upper-cased with every '-' replaced by '_' (C macro style)."""
    return '_'.join(name.upper().split('-'))
21
22
def c_lower(name):
    """Return *name* lower-cased with every '-' replaced by '_' (C identifier style)."""
    return name.lower().translate(str.maketrans('-', '_'))
25
26
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value

    The name must have the form ``[us]<width>-(min|max)``, e.g. ``u8-max``
    or ``s32-min``.  Unsigned limits span [0, 2**width - 1], signed ones
    [-2**(width-1), 2**(width-1) - 1].

    Raises ValueError for any other string, instead of silently misparsing
    it (e.g. 'u32-foo' used to be treated as 'u32-max').
    """
    parsed = re.fullmatch(r'([us])([0-9]+)-(min|max)', name)
    if parsed is None:
        raise ValueError(f'Invalid limit name "{name}"')
    sign, width, end = parsed.group(1), int(parsed.group(2)), parsed.group(3)

    if sign == 'u':
        return 0 if end == 'min' else (1 << width) - 1

    # Signed: one bit of the width is taken by the sign
    value = (1 << (width - 1)) - 1
    if end == 'min':
        value = -value - 1
    return value
40
41
class BaseNlLib:
    """C expressions the generated code uses to reach library state."""

    def get_family_id(self):
        """Return the C expression generated code uses for the family ID."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
45
46
class Type(SpecAttr):
    """
    Code-generation view of a single attribute of an attribute set.

    Subclasses specialise the hooks below for each spec attribute type.  The
    defaults implement "bit" presence tracking; hooks without a sensible
    default raise, so an unsupported type fails at generation time instead
    of producing bad C.

    Rendering methods take ``ri`` from the caller; the code in this class
    uses its ``cw`` member (the code writer) and its ``op`` member.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        # When the spec has 'checks' this is the spec's own dict, so keys
        # added to self.checks later are also visible through attr['checks'].
        self.checks = attr.get('checks', {})

        # Flipped by set_request() / set_reply()
        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            # A nest of the set named after the family uses just the family name
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # Same rule as Struct: a nested set whose name is also a key of
            # family.consts gets a trailing '_' on its struct name
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Member name in C: suffix names found in _C_KW (presumably the C
        # keywords) and prefix names that would start with a digit
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attr as used in a request; also marks a subset's real attr."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr as used in a reply; also marks a subset's real attr."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """
        Return the numeric value of check *limit* (e.g. 'min', 'max').

        The check may be a literal int, the name of a family constant, or a
        type-limit string such as 'u32-max'.  Returns None when neither the
        check nor *default* is set.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """
        Return check *limit* rendered as C source text.

        Literal ints are printed with *suffix* appended; family constants are
        referred to by their C name; anything else is upper-cased as-is.
        Returns '' when the check is not set.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            # Constants carrying a 'header' key are used without family prefix
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute enum_name; reject subsets that override the real attr's checks."""
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        """Return truthy if the member is an array with a count (base: None)."""
        return None

    def is_scalar(self):
        """Return True for the fixed-width integer types listed below."""
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        """Return True if the attr is a recursive nest (base: never)."""
        return False

    def is_recursive_for_op(self, ri):
        """Recursion only matters when rendering outside an operation (no ri.op)."""
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """How presence is tracked in ``_present``: 'bit' for the base class."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """
        Return the ``_present`` member declaration for this attr.

        Only produced when the attr's presence type equals *type_filter* and
        is 'bit' or 'len'; otherwise returns None.  User-space declarations
        use ``__u32``, others ``u32``.
        """
        if self.presence_type() != type_filter:
            return

        if self.presence_type() == 'bit':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() == 'len':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}_len;"

    def _complex_member_type(self, ri):
        """Return the C type of a non-trivial member, or None (base)."""
        return None

    def free_needs_iter(self):
        """Return True if free() emits a loop over an index variable."""
        return False

    def free(self, ri, var, ref):
        """Emit code freeing the member if it is an array or has 'len' presence."""
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the C parameter declarations a setter takes for this attr."""
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for this attr."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            # Recursive nests are held by pointer
            if self.is_recursive_for_op(ri):
                ptr = '*'
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Return the initializer of the netlink policy entry (base: type only)."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit the netlink policy entry; big-endian u16/u32 use NLA_BE16/NLA_BE32."""
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the fields of the YNL parse policy entry; subclasses must override."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit the YNL parse policy ("typol") entry for this attr."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        """Emit *line*, guarded by the attr's presence bit or length if it has one."""
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a guarded ynl_attr_put_<put_type>() call for the member."""
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        """Emit code adding this attr to a message; subclasses must override."""
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """
        Return (lines, init_lines, local_vars) of parse code for this attr.

        Each element may be a single string, a list of strings, or None.
        Subclasses must override.
        """
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Emit this attr's branch of the parse if / else-if chain.

        *first* selects 'if' instead of 'else if'.  Non-array attrs are
        validated and, with bit presence, marked present.  Returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        # _attr_get() may hand back a single line as a plain string
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        """Return the setter body lines storing the value; subclasses must override."""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """
        Emit a ``static inline void`` setter for this attr on a request.

        *ref* is the list of member names of the enclosing nests (outermost
        first).  Presence bits are set for every nest on the path, and for
        the attr itself when it uses bit presence.  Setters whose body calls
        free() but no *alloc() get a '__' name prefix.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        for i, part in enumerate(ref):
            # e.g. "req->_present.a" for the top level, "req->a._present.b" below it
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{part}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == len(ref) - 1 and self.presence_type() != 'bit':
                continue
            code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
284
285
class TypeUnused(Type):
    """
    Attribute slot marked 'unused'.

    It has no presence, policy, put code, parse branch or setter; its YNL
    parse policy entry is YNL_PT_REJECT.
    """

    def presence_type(self):
        # No presence tracking at all
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        error_line = 'return YNL_PARSE_CB_ERROR;'
        return [error_line], None, None

    # Nothing is rendered for an unused slot
    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
310
311
class TypePad(Type):
    """
    Padding attribute.

    It has no presence, policy, put code, parse branch or setter; its YNL
    parse policy entry is YNL_PT_IGNORE.
    """

    def presence_type(self):
        # No presence tracking at all
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    # Nothing is rendered for padding
    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
333
334
class TypeScalar(Type):
    """
    Integer attribute.

    Range checks come from the spec's 'checks' or, for attrs tied to an
    enum, from the enum's value range.  Auto-sized scalars (is_auto_scalar)
    are exposed as 64-bit C types.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Attrs tied to an enum get implicit min / max checks from its range
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if 'min' not in self.checks:
                # A lower bound of 0 adds nothing for unsigned types
                if low != 0 or self.type[0] == 's':
                    self.checks['min'] = low
            if 'max' not in self.checks:
                self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Bounds outside the s16 range are rendered with NLA_POLICY_FULL_RANGE
        # (see _attr_policy) rather than NLA_POLICY_RANGE
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Decide whether the value is a flags mask and pick its user-facing C type."""
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _attr_policy(self, policy):
        """Return the policy initializer: mask, full range, range, min, max or plain."""
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                flags = self.family.consts[self.attr['enum']]
            else:
                flags = self.family.consts[self.checks['flags-mask']]
            # Build the mask from the actual flag values, as the bitfield case
            # always did.  Deriving it from the entry count, (1 << n) - 1, is
            # wrong for flag sets that do not start at bit 0 or have gaps.
            mask = flags.get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed types share the unsigned YNL_PT_U* parse type of their width
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
421
422
class TypeFlag(Type):
    """Flag attribute: it has no payload, so only its presence bit is tracked."""

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        # Zero-length attribute, emitted only when the presence bit is set
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # attr_get() already sets the presence bit; nothing else to parse
        no_lines = []
        return no_lines, None, None

    def arg_member(self, ri):
        return list()

    def _setter_lines(self, ri, member, presence):
        # The presence bit set by setter() is the whole value
        return list()
438
439
class TypeString(Type):
    """String attribute, stored as a heap-allocated, NUL-terminated copy plus its length."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        parts = ['{ .type = ' + policy]
        if 'max-len' in self.checks:
            parts.append(', .len = ' + self.get_limit_str('max-len'))
        parts.append(', }')
        return ''.join(parts)

    def attr_policy(self, cw):
        # NUL termination is required unless the spec sets 'unterminated-ok'
        if self.checks.get('unterminated-ok', False):
            str_policy = 'NLA_STRING'
        else:
            str_policy = 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(str_policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        field = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        # Copy at most the attribute payload and always NUL-terminate the copy
        get_lines = [
            f"{len_mem} = len;",
            f"{field} = malloc(len + 1);",
            f"memcpy({field}, ynl_attr_get_str(attr), len);",
            f"{field}[len] = 0;",
        ]
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        # Replace any previous value with a fresh NUL-terminated copy
        str_len = f"{presence}_len"
        return [
            f"free({member});",
            f"{str_len} = strlen({self.c_name});",
            f"{member} = malloc({str_len} + 1);",
            f'memcpy({member}, {self.c_name}, {str_len});',
            f'{member}[{str_len}] = 0;',
        ]
490
491
class TypeBinary(Type):
    """Binary blob attribute, stored as a heap-allocated copy plus its length."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        # At most one length check is supported for binary attrs
        check_cnt = len(self.checks)
        if check_cnt > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        if check_cnt == 0:
            return '{ .type = NLA_BINARY, }'

        check_name = list(self.checks)[0]
        if check_name not in {'exact-len', 'min-len', 'max-len'}:
            raise Exception('Unsupported check for binary type: ' + check_name)

        if check_name == 'exact-len':
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        if check_name == 'min-len':
            return '{ .len = ' + self.get_limit_str('min-len') + ', }'
        return 'NLA_POLICY_MAX_LEN(' + self.get_limit_str('max-len') + ')'

    def attr_put(self, ri, var):
        data = f"{var}->{self.c_name}"
        data_len = f"{var}->_present.{self.c_name}_len"
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, {data}, {data_len})")

    def _attr_get(self, ri, var):
        field = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        get_lines = [f"{len_mem} = len;",
                     f"{field} = malloc(len);",
                     f"memcpy({field}, ynl_attr_data(attr), len);"]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        # Replace any previous value with a fresh copy of the caller's buffer
        blob_len = f"{presence}_len"
        return [f"free({member});",
                f"{blob_len} = len;",
                f"{member} = malloc({blob_len});",
                f'memcpy({member}, {self.c_name}, {blob_len});']
543
544
class TypeBitfield32(Type):
    """Attribute carried as a struct nla_bitfield32; its policy mask comes from the attr's enum."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        src = f"&{var}->{self.c_name}"
        self._attr_put_line(ri, var,
                            f"ynl_attr_put(nlh, {self.enum_name}, {src}, sizeof(struct nla_bitfield32))")

    def _attr_get(self, ri, var):
        dst = f"&{var}->{self.c_name}"
        return f"memcpy({dst}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
568
569
class TypeNest(Type):
    """
    Nested attribute set.

    It is rendered as an embedded struct, or as a pointer when the nest is
    recursive and no operation is being rendered.
    """

    def is_recursive(self):
        nest = self.family.pure_nested_structs[self.nested_attrs]
        return nest.recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def free(self, ri, var, ref):
        field = f"{var}->{ref}{self.c_name}"
        if self.is_recursive_for_op(ri):
            # Held by pointer: free only if it was allocated
            ri.cw.p(f'if ({field})')
            ri.cw.p(f'{self.nested_render_name}_free({field});')
        else:
            ri.cw.p(f'{self.nested_render_name}_free(&{field});')

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        arg = f"{var}->{self.c_name}"
        if not self.is_recursive_for_op(ri):
            arg = '&' + arg
        self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, {self.enum_name}, {arg})")

    def _attr_get(self, ri, var):
        nest = self.nested_render_name
        init_lines = [f"parg.rsp_policy = &{nest}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [f"if ({nest}_parse(&parg, attr))",
                     "return YNL_PARSE_CB_ERROR;"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        # A nest has no setter of its own: emit setters for its members, with
        # this nest appended to the member path.  Recursive members are skipped.
        path = (ref if ref else []) + [self.c_name]
        members = ri.family.pure_nested_structs[self.nested_attrs].member_list()
        for _, member in members:
            if not member.is_recursive():
                member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
609
610
class TypeMultiAttr(Type):
    """
    Attribute that may appear several times, rendered as an array plus count.

    *base_type* is the type object of a single instance; its policy and
    typol entries are reused as-is.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)
        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        is_nest = 'type' not in self.attr or self.attr['type'] == 'nest'
        if is_nest:
            return self.nested_struct_type
        if self.attr['type'] not in scalars:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")
        scalar_pfx = '__' if ri.ku_space == 'user' else ''
        return scalar_pfx + self.attr['type']

    def free_needs_iter(self):
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def free(self, ri, var, ref):
        field = f"{var}->{ref}{self.c_name}"
        if self.attr['type'] in scalars:
            ri.cw.p(f"free({field});")
            return
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            # Free each nested element, then the array itself
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{field}[i]);')
            ri.cw.p(f"free({field});")
            return
        raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only count occurrences here
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        loop = f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)"
        if self.attr['type'] in scalars:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        elif 'type' not in self.attr or self.attr['type'] == 'nest':
            ri.cw.p(loop)
            self._attr_put_line(ri, var,
                                f"{self.nested_render_name}_put(nlh, {self.enum_name}, &{var}->{self.c_name}[i])")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # For multi-attr we have a count, not presence, hack up the presence
        count = presence[:-(len('_present.') + len(self.c_name))] + "n_" + self.c_name
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{count} = n_{self.c_name};"]
672
673
class TypeArrayNest(Type):
    """'indexed-array' attribute with a sub-type, rendered as an array plus count."""

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
            return self.nested_struct_type
        sub_type = self.attr['sub-type']
        if sub_type not in scalars:
            raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")
        return ('__' if ri.ku_space == 'user' else '') + sub_type

    def _attr_typol(self):
        if self.attr['sub-type'] not in scalars:
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
        return f'.type = YNL_PT_U{c_upper(self.sub_type[1:])}, '

    def _attr_get(self, ri, var):
        # Only count the entries here; the attr is kept in attr_<name>,
        # presumably so the array can be filled in a later pass
        counting = [f'attr_{self.c_name} = attr;',
                    'ynl_attr_for_each_nested(attr2, attr)',
                    f'\t{var}->n_{self.c_name}++;']
        return counting, None, ['const struct nlattr *attr2;']
702
703
class TypeNestTypeValue(Type):
    """
    Nest with optional 'type-value' levels.

    For each level listed in 'type-value', the parser descends one nest
    deeper and reads that attribute's *type* as a value.  The values are
    passed to the nested parse function as extra arguments.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        local_vars = []
        get_lines = []
        cur = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            names = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(names)};')
            local_vars.append(f'__u32 {", ".join(names)};')
            for name in names:
                get_lines.append(f'attr_{name} = ynl_attr_data({cur});')
                get_lines.append(f'{name} = ynl_attr_type(attr_{name});')
                cur = 'attr_' + name
            extra_args = f", {', '.join(names)}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cur}{extra_args});")
        return get_lines, init_lines, local_vars
732
733
class Struct:
    """
    C struct rendered for (part of) an attribute set.

    With *type_list* given only those attributes become members; with
    *type_list* None every attribute of the set does and the struct is
    marked ``nested``.  *inherited* lists member names expected to be
    inherited; see set_inherited().
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            base_name = family.ident_name
        else:
            base_name = family.ident_name + '-' + space_name
        self.render_name = c_lower(base_name)
        self.struct_name = 'struct ' + self.render_name
        # Nested structs whose set name is also a key of family.consts get a '_' suffix
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in names]
        self.attrs = {}

        # Remember the member with the highest value (the last one wins on ties)
        max_val = 0
        self.attr_max_val = None
        for name, attr in self.attr_list:
            if attr.value >= max_val:
                max_val = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        yield from self.attrs

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in member order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record inherited members; they must match the list given at construction."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]
789
790
class EnumEntry(SpecEnumEntry):
    """Enum entry plus the C details the generator needs."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # value_change is True when the value differs from the one a C enum
        # would assign implicitly (previous + 1, or 0 for the first entry).
        # Entries of a flags set always count as changed.
        implicit_value = prev.value + 1 if prev else 0
        self.value_change = self.value != implicit_value or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Set c_name from the set's value prefix and the entry name."""
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
809
810
class EnumSet(SpecEnumSet):
    """Enum / flags set with the C names the generator uses."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        # 'enum-name' overrides the C enum name; a false value leaves
        # enum_name None and makes the user-facing type plain int
        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None
        self.user_type = self.enum_name if self.enum_name else 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        # Kept last: the attributes above are set before the base class
        # builds the entries (presumably some of them are read then)
        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (lowest, highest) entry value; raise unless the count fits the span."""
        values = [entry.value for entry in self.entries.values()]
        low, high = min(values), max(values)

        if high - low + 1 != len(self.entries):
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
846
847
class AttrSet(SpecAttrSet):
    """Attribute set from the spec, extended with C naming information.

    Adds name_prefix / max_name / cnt_name (the C enumerator prefix and the
    names of the max and count enumerators) and, after resolve(), c_name.
    Subsets reuse the names of the set they are a subset of.
    """

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                prefix = yaml['name-prefix']
            elif self.name == family.name:
                # The set named after the family gets the short prefix
                prefix = family.ident_name + '-a-'
            else:
                prefix = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(prefix)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Compute c_name: lower-case, keyword-safe, '' for the family's own set."""
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """Factory hook: build the Type subclass matching elem['type'].

        Multi-attr attributes are additionally wrapped in TypeMultiAttr.

        Raises:
            Exception: for unknown types or unsupported indexed-array
                sub-types.
        """
        kind = elem['type']
        by_kind = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'binary': TypeBinary,
            'bitfield32': TypeBitfield32,
            'nest': TypeNest,
            'nest-type-value': TypeNestTypeValue,
        }

        if kind in scalars:
            type_cls = TypeScalar
        elif kind in by_kind:
            type_cls = by_kind[kind]
        elif kind == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            type_cls = TypeArrayNest
        else:
            raise Exception(f"No typed class for type {elem['type']}")

        attr = type_cls(self.family, self, elem, value)
        if 'multi-attr' in elem and elem['multi-attr']:
            attr = TypeMultiAttr(self.family, self, elem, value, attr)
        return attr
908        return t
909
910
class Operation(SpecOperation):
    """Operation from the spec, extended with C naming information.

    Adds render_name, dual_policy (both 'do' and 'dump' carry a request),
    has_ntf (set by mark_has_ntf()) and, after resolve(), enum_name.
    """

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        def mode_has_request(mode):
            return mode in yaml and 'request' in yaml[mode]

        self.dual_policy = mode_has_request('do') and mode_has_request('dump')

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Resolve the spec op, then build its C command enumerator name."""
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that some other message notifies via this operation."""
        self.has_ntf = True
935        self.has_ntf = True
936
937
class Family(SpecFamily):
    """Spec family extended with the state needed to generate C code.

    On top of SpecFamily this records C naming (c_name, op prefixes, uapi
    header), which attribute sets are message roots and which are only used
    as nests (pure_nested_structs), attribute request/reply usage, the
    pre/post hooks named by operations and, for kernel-policy 'global',
    the shared policy.
    """

    def __init__(self, file_name, exclude_ops):
        """Load the spec.

        Args:
            file_name: path of the YAML spec, passed to SpecFamily.
            exclude_ops: passed through to SpecFamily as exclude_ops.
        """
        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # 'linux/<name>.h' -> '<name>'; anything else falls back to ident_name
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """Compute C naming and all codegen bookkeeping for the family.

        Raises:
            Exception: if the protocol is not a genetlink flavour.
        """
        self.resolve_up(super())

        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
            raise Exception("Codegen only supported for genetlink")

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct, for sets only ever used as nests
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        """Factory hook: wrap spec enum definitions as EnumSet."""
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        """Factory hook: wrap spec attribute sets as AttrSet."""
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        """Factory hook: wrap spec operations as Operation."""
        return Operation(self, elem, req_value, rsp_value)

    def _mark_notify(self):
        """Flag every op named by another message's 'notify' key."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Record, per attribute set used directly by a message, the union of
        attributes used in requests and in replies/events."""
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so a struct follows the non-recursive
        nests it contains, with at most len(structs)**2 passes."""
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                    if self.pure_nested_structs[nested].recursive:
                        continue
                    if nested not in pns_key_seen:
                        # Dicts are sorted, this will make struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_sets(self):
        """Discover nested attribute sets reachable from the root sets.

        Creates a Struct for each, records inherited members ('type-value'
        keys or 'idx' for indexed arrays), and propagates request/reply use,
        child nests and recursion from parents to children.

        Raises:
            Exception: if a set is used both as a root and as a nest.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                if 'type-value' in spec:
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'indexed-array':
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                    struct.child_nests.add(child_name)
                    child = self.pure_nested_structs.get(child_name)
                    if child:
                        if not child.recursive:
                            struct.child_nests.update(child.child_nests)
                        child.request |= struct.request
                        child.reply |= struct.reply
                if attr_set in struct.child_nests:
                    struct.recursive = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Mark attributes as used in requests and/or replies."""
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_global_policy(self):
        """Build the shared policy for kernel-policy 'global'.

        global_policy lists, in attribute-set order, every attribute of the
        common set that appears in some do/dump request.

        Raises:
            Exception: if ops use different attribute sets, or none does.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        # Without this, the lookup below fails with an opaque KeyError: None
        if attr_set_name is None:
            raise Exception('For a global policy at least one op must use an attribute set')

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect unique pre/post hook names per mode, in first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1198                    self.hooks[when][op_mode]['list'].append(name)
1199
1200
class RenderInfo:
    """Context for rendering one message (or standalone nest) in one mode.

    Args:
        cw: CodeWriter to emit into; its nlib is stored as self.nl.
        family: the Family being rendered.
        ku_space: stored as-is (presumably 'kernel' / 'user' - confirm).
        op: the Operation, or a falsy value when rendering a bare attr set.
        op_mode: 'do', 'dump', 'notify', 'event' (or falsy with no op).
        attr_set: attribute set name; defaults to op['attribute-set'].
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        self.fixed_hdr = None
        if op and op.fixed_header:
            self.fixed_hdr = 'struct ' + c_lower(op.fixed_header)

        # 'do' and 'dump' response parsing is identical
        self.type_consistent = True
        # Guard on op like the rest of the constructor: 'x in None' would raise
        if op and op_mode != 'do' and 'dump' in op:
            if 'do' in op:
                if ('reply' in op['do']) != ('reply' in op["dump"]):
                    self.type_consistent = False
                elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False
            else:
                self.type_consistent = False

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            if attr_set in family.consts:
                self.type_name_conflict = True

        self.cw = cw

        self.struct = dict()
        # Notifications are rendered like 'do' replies
        if op_mode == 'notify':
            op_mode = 'do'
        for op_dir in ['request', 'reply']:
            if op:
                type_list = []
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1248            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1249
1250
class CodeWriter:
    """Indentation-aware writer for generated C code.

    Output goes to stdout or, when out_file is given, to a temporary file
    that close_out_file() copies to out_file. With overwrite=False an
    existing out_file whose contents already match is left untouched.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        # Set first so close_out_file() (also run from __del__) can rely on it
        self._tmp = None
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False
        self._block_end = False
        self._silent_block = False
        self._ind = 0
        self._ifdef_block = None
        if out_file is None:
            self._out = sys.stdout
        else:
            self._tmp = tempfile.NamedTemporaryFile('w+')
            self._out = self._tmp
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """Flush buffered output to out_file and switch back to stdout.

        The temporary file is closed on every path, including when the
        target already has identical contents. Calling this again, or on a
        stdout writer, does nothing.
        """
        tmp = getattr(self, '_tmp', None)
        if tmp is None:
            return
        self._tmp = None
        self._out = sys.stdout
        try:
            tmp.flush()
            # Avoid modifying the file if contents didn't change
            unchanged = (not self._overwrite and
                         os.path.isfile(self._out_file) and
                         filecmp.cmp(tmp.name, self._out_file, shallow=False))
            if not unchanged:
                with open(self._out_file, 'w+') as out_file:
                    tmp.seek(0)
                    shutil.copyfileobj(tmp, out_file)
        finally:
            tmp.close()

    @classmethod
    def _is_cond(cls, line):
        return line.startswith('if') or line.startswith('while') or line.startswith('for')

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        A pending '}' from block_end() is flushed first (merged as '} else'
        when line starts with 'else'). Labels ending in ':' are outdented
        one level, preprocessor lines start at column 0, and the line after
        a brace-less if/while/for gets one extra level.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        # endswith/startswith instead of indexing so an empty line is allowed
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next printed line."""
        self._nl = True

    def block_start(self, line=''):
        self.p(line + ' {' if line else '{')
        self._ind += 1

    def block_end(self, line=''):
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Print local variable declarations, longest first, then a blank line.

        local_vars may be one string or a list of strings. The caller's list
        is left in its original order.
        """
        if not local_vars:
            return

        if type(local_vars) is str:
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.write_func_lvar(local_vars=local_vars)

        self.block_start()
        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1445        self._ifdef_block = config_option
1446
1447
# Attribute types rendered as plain C scalars (handled by TypeScalar).
scalars = {'u8', 'u16', 'u32', 'u64', 's8', 's16', 's32', 's64', 'uint', 'sint'}

# Suffix appended to a message's struct name for each direction.
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': ''
}

# Suffix used for reply types of type-consistent ops in each mode
# (see op_prefix()).
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# Words that cannot be used as C identifiers; generated names matching one
# of these get a trailing '_' (as AttrSet.resolve() does).
_C_KW = {
    'auto',
    'bool',
    'break',
    'case',
    'char',
    'const',
    'continue',
    'default',
    'do',
    'double',
    'else',
    'enum',
    'extern',
    'float',
    'for',
    'goto',
    'if',
    'inline',
    'int',
    'long',
    'register',
    'restrict',  # C99 type qualifier
    'return',
    'short',
    'signed',
    'sizeof',
    'static',
    'struct',
    'switch',
    'typedef',
    'union',
    'unsigned',
    'void',
    'volatile',
    'while'
}
1499
1500
def rdir(direction):
    """Return the opposite message direction.

    'reply' and 'request' swap; any other value is returned unchanged.
    """
    if direction in ('reply', 'request'):
        return 'request' if direction == 'reply' else 'reply'
    return direction
1507
1508
def op_prefix(ri, direction, deref=False):
    """Return the C type name (without 'struct ') for ri's message.

    'do' (or no) mode uses the direction suffix; dump requests use
    '_req_dump'. For other replies, type-consistent messages use the
    direction suffix when deref is set and the mode's wrapper suffix
    otherwise; inconsistent ones use '_rsp_dump' / '_rsp_list'.
    """
    mode = ri.op_mode
    if not mode or mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp' + ('_dump' if deref else '_list')
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[mode]

    return f"{ri.family.c_name}_{ri.type_name}{tail}"
1528
1529
def type_name(ri, direction, deref=False):
    """Return the full C struct type ('struct <prefix>') for ri's message."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1532
1533
def print_prototype(ri, direction, terminate=True, doc=None):
    """Print the prototype of the user-space helper for ri's op and mode.

    The helper takes the ynl socket, plus the request struct when the mode
    has a request. It returns a pointer to the reply struct when the mode
    has a reply, and int otherwise. Dump helpers get a '_dump' suffix.
    terminate appends ';'.
    """
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')
    mode_spec = ri.op[ri.op_mode]

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        args.append(f"{type_name(ri, direction)} *{direction_to_suffix[direction][1:]}")

    ret = f"{type_name(ri, rdir(direction))} *" if 'reply' in mode_spec else 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1550
1551
def print_req_prototype(ri):
    """Print the request helper prototype, preceded by the op's doc comment.

    An op without a 'doc' key gets a prototype with no comment instead of
    failing with KeyError.
    """
    doc = ri.op['doc'] if 'doc' in ri.op else None
    print_prototype(ri, "request", doc=doc)
1554
1555
def print_dump_prototype(ri):
    """Print the dump request helper prototype (no doc comment)."""
    print_prototype(ri, direction="request")
1558
1559
def put_typol_fwd(cw, struct):
    """Emit the extern declaration of a struct's ynl_policy_nest."""
    nest_name = f'{struct.render_name}_nest'
    cw.p('extern const struct ynl_policy_nest ' + nest_name + ';')
1562
1563
def put_typol(cw, struct):
    """Emit a struct's per-attribute policy array and its ynl_policy_nest."""
    name = struct.render_name
    max_attr = struct.attr_set.max_name

    cw.block_start(line=f'const struct ynl_policy_attr {name}_policy[{max_attr} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {name}_nest =')
    for entry in (f'.max_attr = {max_attr},', f'.table = {name}_policy,'):
        cw.p(entry)
    cw.block_end(line=';')
    cw.nl()
1579
1580
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit '<render_name>_str()', a bounds-checked lookup into map_name.

    The argument is an int, or enum.user_type when enum is given. For a
    'flags' enum the argument is first converted with ffs() - 1, i.e. from
    a single-bit value to its bit index. Out-of-range values return NULL.
    """
    arg_type = enum.user_type if enum else 'int'
    cw.write_func_prot('const char *', f'{render_name}_str', [f'{arg_type} {arg_name}'])

    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    body = (
        f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))',
        'return NULL;',
        f'return {map_name}[{arg_name}];',
    )
    for line in body:
        cw.p(line)
    cw.block_end()
    cw.nl()
1594
1595
def put_op_name_fwd(family, cw):
    """Emit the prototype of the family's '<c_name>_op_str()' helper."""
    helper = f'{family.c_name}_op_str'
    cw.write_func_prot('const char *', helper, ['int op'], suffix=';')
1598
1599
def put_op_name(family, cw):
    """Emit the op-value -> name string map and its '<c_name>_op_str()' helper.

    Only messages with a reply value are listed. When several messages
    share a reply value, only the one recorded in family.rsp_by_value is
    emitted; the others get a skip comment.
    """
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
1619
1620
def put_enum_to_str_fwd(family, cw, enum):
    """Emit the prototype of the enum's '<render_name>_str()' helper."""
    cw.write_func_prot('const char *', f'{enum.render_name}_str',
                       [f'{enum.user_type} value'], suffix=';')
1624
1625
def put_enum_to_str(family, cw, enum):
    """Emit the enum's value -> name string map and its '_str()' helper."""
    map_name = f'{enum.render_name}_strmap'
    entries = [f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values()]

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for line in entries:
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1635
1636
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of '<render_name>_put()' for a nested struct.

    The helper takes the message, the nest attribute type and the struct
    to serialize. suffix is appended after the closing parenthesis.
    """
    params = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]
    fname = f'{struct.render_name}_put'
    ri.cw.write_func_prot('int', fname, params, suffix=suffix)
1644
1645
def put_req_nested(ri, struct):
    """Emit '<render_name>_put()', which writes struct members as a nest.

    The generated body opens the nest, lets each member emit its own put
    code against 'obj', closes the nest and returns 0.
    """
    put_req_nested_prototype(ri, struct, suffix='')
    cw = ri.cw

    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = ynl_attr_nest_start(nlh, attr_type);")
    for _, member in struct.member_list():
        member.attr_put(ri, "obj")
    cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1662
1663
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a generated C parse function for *struct*.

    The emitted code declares the locals, runs *init_lines*, copies the
    inherited arguments and the fixed header (if any) into dst, and then
    loops once over all attributes, letting each member emit its own
    handling via attr_get().  Members that are indexed arrays or
    multi-attrs then get a second pass.  That pass calloc()s dst-><member>
    using the n_<member> counter (presumably incremented by the code
    attr_get() emits in the first loop) and fills the array in.

    Args:
        ri: render info; supplies the code writer (ri.cw) and fixed_hdr.
        struct: the struct to parse.  struct.nested selects between walking
            the attributes of a nest and walking a whole message.
        init_lines: C statements emitted right after the local variables.
            Extended in place.
        local_vars: C local variable declarations.  Extended in place.
    """
    if struct.nested:
        # Nested structs walk the attributes contained in the `nested` nlattr.
        iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        if ri.fixed_hdr:
            # `+=` on a list extends the caller's list in place.
            local_vars += ['void *hdr;']
        # Whole messages walk the attributes of nlh, starting at an offset of
        # family->hdr_len (presumably the family's fixed header size).
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"

    # Collect the members that need a second, array-filling pass.
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] == 'nest':
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        # Members with nested attributes need a ynl_parse_arg (parg) to call
        # their nested *_parse() function.
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # One element counter per array member; sorted() keeps the output stable.
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if ri.fixed_hdr:
        ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
    # Emitted guard: fail the parse if an array in dst is already populated.
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    ri.cw.nl()
    # First pass: a single loop over all attributes; each member emits its
    # own handling for its attribute type.
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for indexed arrays: allocate dst-><member> and walk the
    # entries of the attr_<member> nest (presumably recorded by the member's
    # attr_get() code in the first pass).
    # NOTE(review): the emitted calloc() result is not checked for NULL.
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attrs: walk all attributes again and copy every
    # occurrence of this attribute type into dst-><member>[i].
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            # NOTE(review): unlike the indexed-array pass above, no
            # ynl_attr_type(attr) argument is passed here; confirm this
            # matches the nested parser's prototype (inherited args).
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return 0; message parsers return a YNL_PARSE_CB_* code.
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1784
1785
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Write the prototype of the nested-struct parser for *struct*.

    The parser takes the parse arg and the nest attribute, followed by one
    __u32 argument per inherited attribute of the struct.
    """
    params = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    params.extend(f'__u32 {name}' for name in struct.inherited)
    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', params,
                          suffix=suffix)
1794
1795
def parse_rsp_nested(ri, struct):
    """Write the complete nested-struct parser function for *struct*."""
    parse_rsp_nested_prototype(ri, struct, suffix='')

    lvars = ['const struct nlattr *attr;',
             f'{struct.ptr_name}dst = yarg->data;']

    if not struct.member_list():
        # Empty nest: nothing to parse, just report success.
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, struct, [], lvars)
1811
1812
def parse_rsp_msg(ri, deref=False):
    """Write the message-level parser for the op's reply.

    Nothing is written unless the current op mode has a reply or is an
    event.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    dst_decl = f'{type_name(ri, "reply", deref=deref)} *dst;'
    cw = ri.cw
    cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse',
                       ['const struct nlmsghdr *nlh',
                        'struct ynl_parse_arg *yarg'])

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply: accept the message without looking at attributes.
        cw.block_start()
        cw.p('return YNL_PARSE_CB_OK;')
        cw.block_end()
        cw.nl()
        return

    _multi_parse(ri, reply, ['dst = yarg->data;'],
                 [dst_decl, 'const struct nlattr *attr;'])
1834
1835
def print_req(ri):
    """Write the blocking "do" request helper for the op.

    The generated C function builds the request (optional fixed header plus
    request attributes) and runs it through ynl_exec().
      * For ops with a reply, it returns a calloc()ed, parsed response, or
        frees it and returns NULL on error.
      * For ops without a reply, it returns 0, or -1 on error.
    """
    direction = "request"
    has_reply = 'reply' in ri.op[ri.op_mode]
    if has_reply:
        ret_ok, ret_err = 'rsp', 'NULL'
    else:
        ret_ok, ret_err = '0', '-1'

    lvars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
             'struct nlmsghdr *nlh;',
             'int err;']
    if has_reply:
        lvars.append(f'{type_name(ri, rdir(direction))} *rsp;')
    if ri.fixed_hdr:
        lvars.extend(['size_t hdr_len;', 'void *hdr;'])

    print_prototype(ri, direction, terminate=False)
    cw = ri.cw
    cw.block_start()
    cw.write_func_lvar(lvars)

    cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()

    if ri.fixed_hdr:
        # Copy the caller's fixed header into the message before the attrs.
        cw.p("hdr_len = sizeof(req->_hdr);")
        cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        cw.nl()

    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    cw.nl()

    if has_reply:
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yrs.yarg.data = rsp;')
        cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        cw.nl()
    cw.p("err = ynl_exec(ys, nlh, &yrs);")
    cw.p('if (err < 0)')
    cw.p('goto err_free;' if has_reply else 'return -1;')
    cw.nl()

    cw.p(f"return {ret_ok};")
    cw.nl()

    if has_reply:
        cw.p('err_free:')
        cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        cw.p(f"return {ret_err};")

    cw.block_end()
1900
1901
def print_dump(ri):
    """Write the dump helper for the op.

    The generated C function starts a dump request and lets ynl_exec_dump()
    collect and parse the replies.  It returns yds.first; on error it frees
    yds.first and returns NULL.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    cw = ri.cw
    cw.block_start()

    lvars = ['struct ynl_dump_state yds = {};',
             'struct nlmsghdr *nlh;',
             'int err;']
    if ri.fixed_hdr:
        lvars.extend(['size_t hdr_len;', 'void *hdr;'])
    cw.write_func_lvar(lvars)

    cw.p('yds.yarg.ys = ys;')
    cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.p("yds.yarg.data = NULL;")
    cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    cw.p(f'yds.rsp_cmd = {rsp_cmd};')
    cw.nl()
    cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if ri.fixed_hdr:
        for stmt in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            cw.p(stmt)
        cw.nl()

    if "request" in ri.op[ri.op_mode]:
        cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    cw.nl()

    cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    cw.p('if (err < 0)')
    cw.p('goto free_list;')
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
1952
1953
def call_free(ri, direction, var):
    """Return the C statement that frees *var* with the op's free helper."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
1956
1957
def free_arg_name(direction):
    """Return the C parameter name used by a generated free helper.

    Direction-less types use 'obj'; otherwise the direction's suffix is
    used without its first character.
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
1962
1963
def print_alloc_wrapper(ri, direction):
    """Write a static inline <prefix>_alloc() that calloc()s the struct."""
    prefix = op_prefix(ri, direction)
    cw = ri.cw
    cw.write_func_prot(f'static inline struct {prefix} *', f"{prefix}_alloc", ["void"])
    cw.block_start()
    cw.p(f'return calloc(1, sizeof(struct {prefix}));')
    cw.block_end()
1970
1971
def print_free_prototype(ri, direction, suffix=';'):
    """Write the prototype of the <prefix>_free() helper."""
    name = op_prefix(ri, direction)
    # On a type name conflict the struct name carries a trailing underscore.
    struct_name = f"{name}_" if ri.type_name_conflict else name
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{name}_free", [param], suffix=suffix)
1979
1980
def _print_type(ri, direction, struct):
    """Write the C struct definition for *struct*.

    The struct is named <family c_name>_<type_name><direction suffix>.  It
    gets a trailing '_' for direction-less types on a name conflict and
    '_dump' in dump mode.  Members are laid out in this order:
      * the fixed header (_hdr);
      * a _present sub-struct of presence members;
      * inherited __u32 fields;
      * each attribute's own member.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        suffix += '_'
    if ri.op_mode == 'dump':
        suffix += '_dump'

    cw = ri.cw
    cw.block_start(line=f"struct {ri.family.c_name}{suffix}")

    if ri.fixed_hdr:
        cw.p(ri.fixed_hdr + ' _hdr;')
        cw.nl()

    presence = []
    for _, attr in struct.member_list():
        for kind in ('len', 'bit'):
            member = attr.presence_member(ri.ku_space, kind)
            if member:
                presence.append(member)
    if presence:
        cw.block_start(line="struct")
        for member in presence:
            cw.p(member)
        cw.block_end(line='_present;')
        cw.nl()

    for name in struct.inherited:
        cw.p(f"__u32 {name};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
2016
2017
def print_type(ri, direction):
    """Write the struct for the given message direction of the op."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
2020
2021
def print_type_full(ri, struct):
    """Write the struct definition for a standalone (direction-less) type."""
    _print_type(ri, direction="", struct=struct)
2024
2025
def print_type_helpers(ri, direction, deref=False):
    """Write the free prototype and, for user-space requests, attr setters."""
    print_free_prototype(ri, direction)
    ri.cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, attr in ri.struct[direction].member_list():
            attr.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
2034
2035
def print_req_type_helpers(ri):
    """Write the alloc/free/setter helpers for a non-empty request struct."""
    if not ri.struct["request"].attr_list:
        return
    for emit in (print_alloc_wrapper, print_type_helpers):
        emit(ri, "request")
2041
2042
def print_rsp_type_helpers(ri):
    """Write the reply helpers, if the current op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2047
2048
def print_parse_prototype(ri, direction, terminate=True):
    """Write the prototype of <op>_rsp_parse() or <op>_req_parse()."""
    kind = "_rsp" if direction == "reply" else "_req"
    base = f"{ri.op.render_name}{kind}"
    ri.cw.write_func_prot('void', f"{base}_parse",
                          ['const struct nlattr **tb', f"struct {base} *req"],
                          suffix=';' if terminate else '')
2057
2058
def print_req_type(ri):
    """Write the request struct unless it has no attributes."""
    if ri.struct["request"].attr_list:
        print_type(ri, "request")
2063
2064
def print_req_free(ri):
    """Write the request free function, if the op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2069
2070
def print_rsp_type(ri):
    """Write the reply struct for do/dump ops with a reply, and for events."""
    mode = ri.op_mode
    if mode == 'event' or (mode in ('do', 'dump') and 'reply' in ri.op[mode]):
        print_type(ri, 'reply')
2079
2080
def print_wrapped_type(ri):
    """Write the wrapper struct that embeds the reply object as 'obj'.

    Dump wrappers carry a 'next' pointer of the same type.  Notify/event
    wrappers carry family, cmd, a generic 'next' and a 'free' callback.
    The free prototype for the wrapper follows the struct.
    """
    cw = ri.cw
    wrapper = type_name(ri, 'reply')
    cw.block_start(line=f"{wrapper}")
    if ri.op_mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        for member in ('__u16 family;',
                       '__u8 cmd;',
                       'struct ynl_ntf_base_type *next;',
                       f"void (*free)({wrapper} *ntf);"):
            cw.p(member)
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
2095
2096
2097def _free_type_members_iter(ri, struct):
2098    for _, attr in struct.member_list():
2099        if attr.free_needs_iter():
2100            ri.cw.p('unsigned int i;')
2101            ri.cw.nl()
2102            break
2103
2104
2105def _free_type_members(ri, var, struct, ref=''):
2106    for _, attr in struct.member_list():
2107        attr.free(ri, var, ref)
2108
2109
def _free_type(ri, direction, struct):
    """Write a <prefix>_free() function for *struct*.

    Each member's free code is emitted (after declaring 'i' if needed).  The
    object pointer itself is free()d only when *direction* is non-empty.
    """
    param = free_arg_name(direction)
    cw = ri.cw

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, param, struct)
    if direction:
        cw.p(f'free({param});')
    cw.block_end()
    cw.nl()
2121
2122
def free_rsp_nested_prototype(ri):
    """Write the free-helper prototype for a nested (direction-less) type."""
    print_free_prototype(ri, "")
2125
2126
def free_rsp_nested(ri, struct):
    """Write the free function for a nested (direction-less) type."""
    _free_type(ri, direction="", struct=struct)
2129
2130
def print_rsp_free(ri):
    """Write the reply free function, if the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2135
2136
def print_dump_type_free(ri):
    """Write the free function for a dump reply list.

    The generated code follows ->next until YNL_LIST_END.  For each node it
    frees the members of node->obj and then the node itself.
    """
    node_type = type_name(ri, 'reply')
    reply = ri.struct['reply']
    cw = ri.cw

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{node_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply)
    cw.p('rsp = next;')
    cw.p('next = rsp->next;')
    cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
2155
2156
def print_ntf_type_free(ri):
    """Write the free function for a notification object (members + itself)."""
    reply = ri.struct['reply']
    cw = ri.cw
    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
2165
2166
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Write a policy array declaration or the opening of its definition.

    terminate=True writes an 'extern' declaration.  terminate=False writes
    the definition head ending in ' = {'.  For op policies (ri given) of
    families whose policies are static, no declaration is written and the
    definition is 'static'.
    """
    static_policy = ri and policy_should_be_static(struct.family)
    if terminate:
        if static_policy:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if static_policy else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    elif ri.op.dual_policy:
        name = f"{ri.op.render_name}_{ri.op_mode}"
    else:
        name = ri.op.render_name
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2189
2190
def print_req_policy(cw, struct, ri=None):
    """Write a complete nla_policy array definition for *struct*.

    When an op is given, its 'config-cond' (or None) is passed to
    cw.ifdef_block() before the definition.  cw.ifdef_block(None) is called
    after it.
    """
    if ri and ri.op:
        cw.ifdef_block(ri.op.get('config-cond', None))
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for member in (attr for _, attr in struct.member_list()):
        member.attr_policy(cw)
    cw.p("};")
    cw.ifdef_block(None)
    cw.nl()
2200
2201
def kernel_can_gen_family_struct(family):
    """Return whether the genl_family struct is generated for *family*.

    Only families whose protocol is plain 'genetlink' qualify.
    """
    proto = family.proto
    return proto == 'genetlink'
2204
2205
def policy_should_be_static(family):
    """Return whether per-op policy arrays should be defined 'static'.

    This is the case for split kernel policies, or when the family struct
    itself is generated.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2208
2209
def print_kernel_policy_ranges(family, cw):
    """Write netlink_range_validation structs for 'full-range' checks.

    One static struct is written per request attribute that has a
    'full-range' check.  Attribute sets that are subsets of another set are
    skipped.  The _signed variant is used when the type does not start with
    'u'.  A header comment precedes the first struct.
    """
    printed_header = False
    for attr_set in family.attr_sets.values():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if not printed_header:
                cw.p('/* Integer value ranges */')
                printed_header = True

            is_unsigned = attr.type[0] == 'u'
            sign = '' if is_unsigned else '_signed'
            suffix = 'ULL' if is_unsigned else 'LL'
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            members = [(limit, attr.get_limit_str(limit, suffix=suffix))
                       for limit in ('min', 'max') if limit in attr.checks]
            cw.write_struct_init(members)
            cw.block_end(line=';')
            cw.nl()
2237
2238
def print_kernel_op_table_fwd(family, cw, terminate):
    """Write the ops table head and, for headers, handler prototypes.

    With terminate=False this opens the <family>_nl_ops[] definition; the
    caller writes the entries and closes it.  With terminate=True it writes
    an extern declaration of the array, but only when the array is exported
    (the family struct is not generated).  It then writes prototypes for the
    pre/post hooks and for every non-async op's doit/dumpit handler.
    """
    # The ops array is non-static when the genl_family struct is not
    # generated here.
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.ident_name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        # Array size: left empty for the static definition (sized by its
        # initializer); for exported split tables one entry per do/dump
        # mode; otherwise one per op.
        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            cnt = 0
            for op in family.ops.values():
                if 'do' in op:
                    cnt += 1
                if 'dump' in op:
                    cnt += 1
        else:
            cnt = len(family.ops)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    # Header only: prototypes for user-provided pre/post hooks ...
    cw.nl()
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    # ... and for the doit/dumpit handlers of every non-async op.
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.ident_name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.ident_name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
2304
2305
def print_kernel_op_table_hdr(family, cw):
    """Write the header-side ops table declaration and handler prototypes."""
    return print_kernel_op_table_fwd(family, cw, terminate=True)
2308
2309
def print_kernel_op_table(family, cw):
    """Write the <family>_nl_ops[] array definition for the kernel source.

    print_kernel_op_table_fwd(terminate=False) opens the definition; this
    function writes one initializer per entry and the closing '};'.
      * 'global' and 'per-op' policies: one entry per non-async op (cmd,
        validate flags, doit/dumpit, per-op policy and flags).
      * 'split' policy: one entry per (op, mode) pair, with pre/post
        callbacks, a per-mode or shared policy and a GENL_CMD_CAP_DO /
        GENL_CMD_CAP_DUMP flag.
    Before each entry, cw.ifdef_block() is given the op's 'config-cond'
    (None when absent).
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            # Async ops (presumably notifications/events) have no handlers.
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): the policy is always built from the op's
                # 'do' request; an op without do.request would presumably
                # fail here -- confirm specs guarantee it for per-op.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Member names used for the pre/post callbacks of each mode's entry.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the flags relevant to this mode: drop the
                    # dump-only flags for 'do' and 'strict' for 'dump'.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Ops with separate do/dump policies get the mode in the
                    # policy name.
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split entry advertises its mode via GENL_CMD_CAP_*.
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    # Reset the conditional state left by the last entry before closing.
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2388
2389
def print_kernel_mcgrp_hdr(family, cw):
    """Write the enum of multicast group IDs for the kernel header."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{grp['name']},"))
    cw.block_end(';')
    cw.nl()
2400
2401
def print_kernel_mcgrp_src(family, cw):
    """Write the genl_multicast_group array indexed by the group IDs."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        grp_id = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{grp_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2413
2414
def print_kernel_family_struct_hdr(family, cw):
    """Declare the generated genl_family and its sock-priv hooks."""
    if not kernel_can_gen_family_struct(family):
        return

    prefix = family.c_name
    cw.p(f"extern struct genl_family {prefix}_nl_family;")
    cw.nl()
    if 'sock-priv' not in family.kernel_family:
        return

    priv_type = family.kernel_family["sock-priv"]
    for hook in ('init', 'destroy'):
        cw.p(f'void {prefix}_nl_sock_priv_{hook}({priv_type} *priv);')
    cw.nl()
2425
2426
def print_kernel_family_struct_src(family, cw):
    """Write the genl_family definition for the kernel source.

    Nothing is written unless kernel_can_gen_family_struct() allows it.
    When a 'sock-priv' type is configured, static wrappers taking void * are
    first written around the user-provided
    <c_name>_nl_sock_priv_init/destroy() functions ("trampolines" for CFI).
    They are then wired into the struct.

    The object is named <c_name>_nl_family.  This matches the
    ``extern struct genl_family <c_name>_nl_family`` declaration that
    print_kernel_family_struct_hdr() writes, and the other <c_name>_nl_*
    symbols referenced below.
    """
    if not kernel_can_gen_family_struct(family):
        return

    if 'sock-priv' in family.kernel_family:
        # Generate "trampolines" to make CFI happy
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
                      ["void *priv"])
        cw.nl()
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
                      ["void *priv"])
        cw.nl()

    # Use c_name (not ident_name) so the definition always matches the
    # header's extern declaration.
    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    if 'sock-priv' in family.kernel_family:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
2462
2463
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a C enum named after obj[enum_name] or obj[ckey].

    If obj has the *enum_name* key, its value names the enum; an empty value
    gives an anonymous enum.  Otherwise, obj[ckey] (prefixed with the
    family's c_name) is used when present.  Failing both, the enum is
    anonymous.
    """
    if enum_name in obj:
        explicit = obj[enum_name]
        start_line = f'enum {c_lower(explicit)}' if explicit else 'enum'
    elif ckey and ckey in obj:
        start_line = f'enum {family.c_name}_{c_lower(obj[ckey])}'
    else:
        start_line = 'enum'
    cw.block_start(line=start_line)
2472
2473
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """Write the single uAPI enum listing all message (command) IDs.

    An explicit '= value' is written only where an op's value breaks the
    running sequence.  With *separate_ntf*, notify/event messages are left
    out.  The MAX value is a trailing enum entry or, with *max_by_define*,
    a #define after the enum.
    """
    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    expected = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        if op.value == expected:
            cw.p(op.enum_name + ',')
        else:
            cw.p(f"{op.enum_name} = {op.value},")
            expected = op.value
        expected += 1
    cw.nl()
    if max_by_define:
        cw.p(cnt_name)
        cw.block_end(line=';')
        cw.p(f"#define {max_name} {max_value}")
    else:
        cw.p(cnt_name + ',')
        cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
    cw.nl()
2499
2500
def render_uapi_directional(family, cw, max_by_define):
    """Write separate uAPI message enums for each direction.

    The USER enum lists ops that have 'do' and no 'event'.  The KERNEL enum
    lists ops with a 'do' reply (named <op>_REPLY) plus notifications and
    events, which keep their plain names.

    Each enum:
      * starts with <NAME>_MSG_<DIR>_NONE = 0;
      * writes '= value' only when an op's non-zero value breaks the
        running sequence;
      * ends with a count entry and a MAX value (an enum entry or, with
        *max_by_define*, a #define).
    """
    def emit_enum(direction, none_entry, entries):
        # entries: (enum entry name, op.value) pairs in spec order.
        max_name = f"{family.op_prefix}{direction}_MAX"
        cnt_name = f"__{family.op_prefix}{direction}_CNT"
        max_value = f"({cnt_name} - 1)"

        cw.block_start(line='enum')
        cw.p(c_upper(none_entry))
        val = 0
        for name, value in entries:
            if value and value != val:
                cw.p(f"{name} = {value},")
                val = value
            else:
                cw.p(name + ',')
            val += 1
        cw.nl()
        if max_by_define:
            cw.p(cnt_name)
            cw.block_end(line=';')
            cw.p(f"#define {max_name} {max_value}")
        else:
            cw.p(cnt_name + ',')
            cw.p(f"{max_name} = {max_value}")
            cw.block_end(line=';')
        cw.nl()

    user_msgs = [(op.enum_name, op.value) for op in family.msgs.values()
                 if 'do' in op and 'event' not in op]
    emit_enum('USER', f'{family.name}_MSG_USER_NONE = 0,', user_msgs)

    kernel_msgs = []
    for op in family.msgs.values():
        is_ntf = 'notify' in op or 'event' in op
        if not (is_ntf or ('do' in op and 'reply' in op['do'])):
            continue
        name = op.enum_name if is_ntf else f'{op.enum_name}_REPLY'
        kernel_msgs.append((name, op.value))
    emit_enum('KERNEL', f'{family.name}_MSG_KERNEL_NONE = 0,', kernel_msgs)
2553
2554
2555def render_uapi(family, cw):
2556    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
2557    hdr_prot = hdr_prot.replace('/', '_')
2558    cw.p('#ifndef ' + hdr_prot)
2559    cw.p('#define ' + hdr_prot)
2560    cw.nl()
2561
2562    defines = [(family.fam_key, family["name"]),
2563               (family.ver_key, family.get('version', 1))]
2564    cw.writes_defines(defines)
2565    cw.nl()
2566
2567    defines = []
2568    for const in family['definitions']:
2569        if const.get('header'):
2570            continue
2571
2572        if const['type'] != 'const':
2573            cw.writes_defines(defines)
2574            defines = []
2575            cw.nl()
2576
2577        # Write kdoc for enum and flags (one day maybe also structs)
2578        if const['type'] == 'enum' or const['type'] == 'flags':
2579            enum = family.consts[const['name']]
2580
2581            if enum.header:
2582                continue
2583
2584            if enum.has_doc():
2585                if enum.has_entry_doc():
2586                    cw.p('/**')
2587                    doc = ''
2588                    if 'doc' in enum:
2589                        doc = ' - ' + enum['doc']
2590                    cw.write_doc_line(enum.enum_name + doc)
2591                else:
2592                    cw.p('/*')
2593                    cw.write_doc_line(enum['doc'], indent=False)
2594                for entry in enum.entries.values():
2595                    if entry.has_doc():
2596                        doc = '@' + entry.c_name + ': ' + entry['doc']
2597                        cw.write_doc_line(doc)
2598                cw.p(' */')
2599
2600            uapi_enum_start(family, cw, const, 'name')
2601            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
2602            for entry in enum.entries.values():
2603                suffix = ','
2604                if entry.value_change:
2605                    suffix = f" = {entry.user_value()}" + suffix
2606                cw.p(entry.c_name + suffix)
2607
2608            if const.get('render-max', False):
2609                cw.nl()
2610                cw.p('/* private: */')
2611                if const['type'] == 'flags':
2612                    max_name = c_upper(name_pfx + 'mask')
2613                    max_val = f' = {enum.get_mask()},'
2614                    cw.p(max_name + max_val)
2615                else:
2616                    cnt_name = enum.enum_cnt_name
2617                    max_name = c_upper(name_pfx + 'max')
2618                    if not cnt_name:
2619                        cnt_name = '__' + name_pfx + 'max'
2620                    cw.p(c_upper(cnt_name) + ',')
2621                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
2622            cw.block_end(line=';')
2623            cw.nl()
2624        elif const['type'] == 'const':
2625            defines.append([c_upper(family.get('c-define-name',
2626                                               f"{family.ident_name}-{const['name']}")),
2627                            const['value']])
2628
2629    if defines:
2630        cw.writes_defines(defines)
2631        cw.nl()
2632
2633    max_by_define = family.get('max-by-define', False)
2634
2635    for _, attr_set in family.attr_sets.items():
2636        if attr_set.subset_of:
2637            continue
2638
2639        max_value = f"({attr_set.cnt_name} - 1)"
2640
2641        val = 0
2642        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
2643        for _, attr in attr_set.items():
2644            suffix = ','
2645            if attr.value != val:
2646                suffix = f" = {attr.value},"
2647                val = attr.value
2648            val += 1
2649            cw.p(attr.enum_name + suffix)
2650        if attr_set.items():
2651            cw.nl()
2652        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
2653        if not max_by_define:
2654            cw.p(f"{attr_set.max_name} = {max_value}")
2655        cw.block_end(line=';')
2656        if max_by_define:
2657            cw.p(f"#define {attr_set.max_name} {max_value}")
2658        cw.nl()
2659
2660    # Commands
2661    separate_ntf = 'async-prefix' in family['operations']
2662
2663    if family.msg_id_model == 'unified':
2664        render_uapi_unified(family, cw, max_by_define, separate_ntf)
2665    elif family.msg_id_model == 'directional':
2666        render_uapi_directional(family, cw, max_by_define)
2667    else:
2668        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')
2669
2670    if separate_ntf:
2671        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
2672        for op in family.msgs.values():
2673            if separate_ntf and not ('notify' in op or 'event' in op):
2674                continue
2675
2676            suffix = ','
2677            if 'value' in op:
2678                suffix = f" = {op['value']},"
2679            cw.p(op.enum_name + suffix)
2680        cw.block_end(line=';')
2681        cw.nl()
2682
2683    # Multicast
2684    defines = []
2685    for grp in family.mcgrps['list']:
2686        name = grp['name']
2687        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
2688                        f'{name}'])
2689    cw.nl()
2690    if defines:
2691        cw.writes_defines(defines)
2692        cw.nl()
2693
2694    cw.p(f'#endif /* {hdr_prot} */')
2695
2696
2697def _render_user_ntf_entry(ri, op):
2698    ri.cw.block_start(line=f"[{op.enum_name}] = ")
2699    ri.cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
2700    ri.cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
2701    ri.cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
2702    ri.cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
2703    ri.cw.block_end(line=',')
2704
2705
2706def render_user_family(family, cw, prototype):
2707    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
2708    if prototype:
2709        cw.p(f'extern {symbol};')
2710        return
2711
2712    if family.ntfs:
2713        cw.block_start(line=f"static const struct ynl_ntf_info {family['name']}_ntf_info[] = ")
2714        for ntf_op_name, ntf_op in family.ntfs.items():
2715            if 'notify' in ntf_op:
2716                op = family.ops[ntf_op['notify']]
2717                ri = RenderInfo(cw, family, "user", op, "notify")
2718            elif 'event' in ntf_op:
2719                ri = RenderInfo(cw, family, "user", ntf_op, "event")
2720            else:
2721                raise Exception('Invalid notification ' + ntf_op_name)
2722            _render_user_ntf_entry(ri, ntf_op)
2723        for op_name, op in family.ops.items():
2724            if 'event' not in op:
2725                continue
2726            ri = RenderInfo(cw, family, "user", op, "event")
2727            _render_user_ntf_entry(ri, op)
2728        cw.block_end(line=";")
2729        cw.nl()
2730
2731    cw.block_start(f'{symbol} = ')
2732    cw.p(f'.name\t\t= "{family.c_name}",')
2733    if family.fixed_header:
2734        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
2735    else:
2736        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
2737    if family.ntfs:
2738        cw.p(f".ntf_info\t= {family['name']}_ntf_info,")
2739        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family['name']}_ntf_info),")
2740    cw.block_end(line=';')
2741
2742
2743def family_contains_bitfield32(family):
2744    for _, attr_set in family.attr_sets.items():
2745        if attr_set.subset_of:
2746            continue
2747        for _, attr in attr_set.items():
2748            if attr.type == "bitfield32":
2749                return True
2750    return False
2751
2752
2753def find_kernel_root(full_path):
2754    sub_path = ''
2755    while True:
2756        sub_path = os.path.join(os.path.basename(full_path), sub_path)
2757        full_path = os.path.dirname(full_path)
2758        maintainers = os.path.join(full_path, "MAINTAINERS")
2759        if os.path.exists(maintainers):
2760            return full_path, sub_path[:-1]
2761
2762
2763def main():
2764    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
2765    parser.add_argument('--mode', dest='mode', type=str, required=True,
2766                        choices=('user', 'kernel', 'uapi'))
2767    parser.add_argument('--spec', dest='spec', type=str, required=True)
2768    parser.add_argument('--header', dest='header', action='store_true', default=None)
2769    parser.add_argument('--source', dest='header', action='store_false')
2770    parser.add_argument('--user-header', nargs='+', default=[])
2771    parser.add_argument('--cmp-out', action='store_true', default=None,
2772                        help='Do not overwrite the output file if the new output is identical to the old')
2773    parser.add_argument('--exclude-op', action='append', default=[])
2774    parser.add_argument('-o', dest='out_file', type=str, default=None)
2775    args = parser.parse_args()
2776
2777    if args.header is None:
2778        parser.error("--header or --source is required")
2779
2780    exclude_ops = [re.compile(expr) for expr in args.exclude_op]
2781
2782    try:
2783        parsed = Family(args.spec, exclude_ops)
2784        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
2785            print('Spec license:', parsed.license)
2786            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
2787            os.sys.exit(1)
2788    except yaml.YAMLError as exc:
2789        print(exc)
2790        os.sys.exit(1)
2791        return
2792
2793    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))
2794
2795    _, spec_kernel = find_kernel_root(args.spec)
2796    if args.mode == 'uapi' or args.header:
2797        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
2798    else:
2799        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
2800    cw.p("/* Do not edit directly, auto-generated from: */")
2801    cw.p(f"/*\t{spec_kernel} */")
2802    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
2803    if args.exclude_op or args.user_header:
2804        line = ''
2805        line += ' --user-header '.join([''] + args.user_header)
2806        line += ' --exclude-op '.join([''] + args.exclude_op)
2807        cw.p(f'/* YNL-ARG{line} */')
2808    cw.nl()
2809
2810    if args.mode == 'uapi':
2811        render_uapi(parsed, cw)
2812        return
2813
2814    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
2815    if args.header:
2816        cw.p('#ifndef ' + hdr_prot)
2817        cw.p('#define ' + hdr_prot)
2818        cw.nl()
2819
2820    if args.out_file:
2821        hdr_file = os.path.basename(args.out_file[:-2]) + ".h"
2822    else:
2823        hdr_file = "generated_header_file.h"
2824
2825    if args.mode == 'kernel':
2826        cw.p('#include <net/netlink.h>')
2827        cw.p('#include <net/genetlink.h>')
2828        cw.nl()
2829        if not args.header:
2830            if args.out_file:
2831                cw.p(f'#include "{hdr_file}"')
2832            cw.nl()
2833        headers = ['uapi/' + parsed.uapi_header]
2834        headers += parsed.kernel_family.get('headers', [])
2835    else:
2836        cw.p('#include <stdlib.h>')
2837        cw.p('#include <string.h>')
2838        if args.header:
2839            cw.p('#include <linux/types.h>')
2840            if family_contains_bitfield32(parsed):
2841                cw.p('#include <linux/netlink.h>')
2842        else:
2843            cw.p(f'#include "{hdr_file}"')
2844            cw.p('#include "ynl.h"')
2845        headers = []
2846    for definition in parsed['definitions']:
2847        if 'header' in definition:
2848            headers.append(definition['header'])
2849    if args.mode == 'user':
2850        headers.append(parsed.uapi_header)
2851    seen_header = []
2852    for one in headers:
2853        if one not in seen_header:
2854            cw.p(f"#include <{one}>")
2855            seen_header.append(one)
2856    cw.nl()
2857
2858    if args.mode == "user":
2859        if not args.header:
2860            cw.p("#include <linux/genetlink.h>")
2861            cw.nl()
2862            for one in args.user_header:
2863                cw.p(f'#include "{one}"')
2864        else:
2865            cw.p('struct ynl_sock;')
2866            cw.nl()
2867            render_user_family(parsed, cw, True)
2868        cw.nl()
2869
2870    if args.mode == "kernel":
2871        if args.header:
2872            for _, struct in sorted(parsed.pure_nested_structs.items()):
2873                if struct.request:
2874                    cw.p('/* Common nested types */')
2875                    break
2876            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2877                if struct.request:
2878                    print_req_policy_fwd(cw, struct)
2879            cw.nl()
2880
2881            if parsed.kernel_policy == 'global':
2882                cw.p(f"/* Global operation policy for {parsed.name} */")
2883
2884                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
2885                print_req_policy_fwd(cw, struct)
2886                cw.nl()
2887
2888            if parsed.kernel_policy in {'per-op', 'split'}:
2889                for op_name, op in parsed.ops.items():
2890                    if 'do' in op and 'event' not in op:
2891                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
2892                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
2893                        cw.nl()
2894
2895            print_kernel_op_table_hdr(parsed, cw)
2896            print_kernel_mcgrp_hdr(parsed, cw)
2897            print_kernel_family_struct_hdr(parsed, cw)
2898        else:
2899            print_kernel_policy_ranges(parsed, cw)
2900
2901            for _, struct in sorted(parsed.pure_nested_structs.items()):
2902                if struct.request:
2903                    cw.p('/* Common nested types */')
2904                    break
2905            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2906                if struct.request:
2907                    print_req_policy(cw, struct)
2908            cw.nl()
2909
2910            if parsed.kernel_policy == 'global':
2911                cw.p(f"/* Global operation policy for {parsed.name} */")
2912
2913                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
2914                print_req_policy(cw, struct)
2915                cw.nl()
2916
2917            for op_name, op in parsed.ops.items():
2918                if parsed.kernel_policy in {'per-op', 'split'}:
2919                    for op_mode in ['do', 'dump']:
2920                        if op_mode in op and 'request' in op[op_mode]:
2921                            cw.p(f"/* {op.enum_name} - {op_mode} */")
2922                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
2923                            print_req_policy(cw, ri.struct['request'], ri=ri)
2924                            cw.nl()
2925
2926            print_kernel_op_table(parsed, cw)
2927            print_kernel_mcgrp_src(parsed, cw)
2928            print_kernel_family_struct_src(parsed, cw)
2929
2930    if args.mode == "user":
2931        if args.header:
2932            cw.p('/* Enums */')
2933            put_op_name_fwd(parsed, cw)
2934
2935            for name, const in parsed.consts.items():
2936                if isinstance(const, EnumSet):
2937                    put_enum_to_str_fwd(parsed, cw, const)
2938            cw.nl()
2939
2940            cw.p('/* Common nested types */')
2941            for attr_set, struct in parsed.pure_nested_structs.items():
2942                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
2943                print_type_full(ri, struct)
2944
2945            for op_name, op in parsed.ops.items():
2946                cw.p(f"/* ============== {op.enum_name} ============== */")
2947
2948                if 'do' in op and 'event' not in op:
2949                    cw.p(f"/* {op.enum_name} - do */")
2950                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
2951                    print_req_type(ri)
2952                    print_req_type_helpers(ri)
2953                    cw.nl()
2954                    print_rsp_type(ri)
2955                    print_rsp_type_helpers(ri)
2956                    cw.nl()
2957                    print_req_prototype(ri)
2958                    cw.nl()
2959
2960                if 'dump' in op:
2961                    cw.p(f"/* {op.enum_name} - dump */")
2962                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
2963                    print_req_type(ri)
2964                    print_req_type_helpers(ri)
2965                    if not ri.type_consistent:
2966                        print_rsp_type(ri)
2967                    print_wrapped_type(ri)
2968                    print_dump_prototype(ri)
2969                    cw.nl()
2970
2971                if op.has_ntf:
2972                    cw.p(f"/* {op.enum_name} - notify */")
2973                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
2974                    if not ri.type_consistent:
2975                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
2976                    print_wrapped_type(ri)
2977
2978            for op_name, op in parsed.ntfs.items():
2979                if 'event' in op:
2980                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
2981                    cw.p(f"/* {op.enum_name} - event */")
2982                    print_rsp_type(ri)
2983                    cw.nl()
2984                    print_wrapped_type(ri)
2985            cw.nl()
2986        else:
2987            cw.p('/* Enums */')
2988            put_op_name(parsed, cw)
2989
2990            for name, const in parsed.consts.items():
2991                if isinstance(const, EnumSet):
2992                    put_enum_to_str(parsed, cw, const)
2993            cw.nl()
2994
2995            has_recursive_nests = False
2996            cw.p('/* Policies */')
2997            for struct in parsed.pure_nested_structs.values():
2998                if struct.recursive:
2999                    put_typol_fwd(cw, struct)
3000                    has_recursive_nests = True
3001            if has_recursive_nests:
3002                cw.nl()
3003            for name in parsed.pure_nested_structs:
3004                struct = Struct(parsed, name)
3005                put_typol(cw, struct)
3006            for name in parsed.root_sets:
3007                struct = Struct(parsed, name)
3008                put_typol(cw, struct)
3009
3010            cw.p('/* Common nested types */')
3011            if has_recursive_nests:
3012                for attr_set, struct in parsed.pure_nested_structs.items():
3013                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3014                    free_rsp_nested_prototype(ri)
3015                    if struct.request:
3016                        put_req_nested_prototype(ri, struct)
3017                    if struct.reply:
3018                        parse_rsp_nested_prototype(ri, struct)
3019                cw.nl()
3020            for attr_set, struct in parsed.pure_nested_structs.items():
3021                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3022
3023                free_rsp_nested(ri, struct)
3024                if struct.request:
3025                    put_req_nested(ri, struct)
3026                if struct.reply:
3027                    parse_rsp_nested(ri, struct)
3028
3029            for op_name, op in parsed.ops.items():
3030                cw.p(f"/* ============== {op.enum_name} ============== */")
3031                if 'do' in op and 'event' not in op:
3032                    cw.p(f"/* {op.enum_name} - do */")
3033                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3034                    print_req_free(ri)
3035                    print_rsp_free(ri)
3036                    parse_rsp_msg(ri)
3037                    print_req(ri)
3038                    cw.nl()
3039
3040                if 'dump' in op:
3041                    cw.p(f"/* {op.enum_name} - dump */")
3042                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
3043                    if not ri.type_consistent:
3044                        parse_rsp_msg(ri, deref=True)
3045                    print_req_free(ri)
3046                    print_dump_type_free(ri)
3047                    print_dump(ri)
3048                    cw.nl()
3049
3050                if op.has_ntf:
3051                    cw.p(f"/* {op.enum_name} - notify */")
3052                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
3053                    if not ri.type_consistent:
3054                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
3055                    print_ntf_type_free(ri)
3056
3057            for op_name, op in parsed.ntfs.items():
3058                if 'event' in op:
3059                    cw.p(f"/* {op.enum_name} - event */")
3060
3061                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3062                    parse_rsp_msg(ri)
3063
3064                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
3065                    print_ntf_type_free(ri)
3066            cw.nl()
3067            render_user_family(parsed, cw, False)
3068
3069    if args.header:
3070        cw.p(f'#endif /* {hdr_prot} */')
3071
3072
3073if __name__ == "__main__":
3074    main()
3075