xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision d0309c054362a235077327b46f727bc48878a3bc)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import filecmp
6import pathlib
7import os
8import re
9import shutil
10import sys
11import tempfile
12import yaml
13
14sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
15from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
16from lib import SpecSubMessage
17
18
def c_upper(name):
    """Return name in upper case with every '-' replaced by '_'."""
    return '_'.join(name.upper().split('-'))
21
22
def c_lower(name):
    """Return name in lower case with every '-' replaced by '_'."""
    lowered = name.lower()
    return lowered.translate({ord('-'): '_'})
25
26
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value

    The name must have the form <u|s><bits>-<min|max>. Anything else raises
    ValueError instead of being silently interpreted as some limit.
    """
    match = re.fullmatch(r'([su])([0-9]+)-(min|max)', name)
    if match is None:
        raise ValueError(f'Invalid limit "{name}", expected a name like u32-max or s64-min')

    sign, bits, bound = match.groups()
    if sign == 'u' and bound == 'min':
        return 0

    width = int(bits)
    if sign == 's':
        # One bit is used up by the sign
        width -= 1
    value = (1 << width) - 1
    if sign == 's' and bound == 'min':
        value = -value - 1
    return value
40
41
class BaseNlLib:
    """Supplies the expression used to refer to the family id."""

    def get_family_id(self):
        """Return the string 'ys->family_id'."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
45
46
class Type(SpecAttr):
    """
    Base class of the attribute types known to the C code generator.

    Stores the naming / presence information common to all attribute types
    and provides default implementations; type specific behavior is added
    by subclasses overriding the hooks (mostly the underscore-prefixed ones).
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        # Flipped by set_request() / set_reply()
        self.request = False
        self.reply = False

        self.is_selector = False

        if 'len' in attr:
            self.len = attr['len']

        # 'nested-attributes' wins over 'sub-message' when both are present
        nested = None
        for key in ('nested-attributes', 'sub-message'):
            if key in attr:
                nested = attr[key]
                break

        if nested:
            self.nested_attrs = nested

            render_name = f"{family.ident_name}"
            if nested != family.name:
                render_name += f"_{nested}"
            self.nested_render_name = c_lower(render_name)

            # The struct name gets a trailing '_' if a const has the same name
            suffix = '_' if nested in self.family.consts else ''
            self.nested_struct_type = 'struct ' + self.nested_render_name + suffix

        c_name = c_lower(self.name)
        if c_name in _C_KW:
            c_name += '_'
        if c_name[0].isdigit():
            c_name = '_' + c_name
        self.c_name = c_name

        # Added by resolve():
        self.enum_name = None
        del self.enum_name

    def _get_real_attr(self):
        """Return the same-named attr from the set this one is a subset of."""
        # Only goes one level down, does not recurse
        parent_set = self.family.attr_sets[self.attr_set.subset_of]
        return parent_set[self.name]

    def set_request(self):
        """Mark the attr (and its "real" attr for subsets) as used in requests."""
        self.request = True
        if not self.attr_set.subset_of:
            return
        self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr (and its "real" attr for subsets) as used in replies."""
        self.reply = True
        if not self.attr_set.subset_of:
            return
        self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """Return the value of check @limit, resolving consts and named limits."""
        value = self.checks.get(limit, default)
        if value is None or isinstance(value, int):
            return value
        consts = self.family.consts
        if value in consts:
            return consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """Return check @limit as a string usable in C, '' if there is none."""
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix

        name = value
        consts = self.family.consts
        # Consts without a 'header' get the family name as a prefix
        if value in consts and not consts[value].get('header'):
            name = f"{self.family['name']}-{value}"
        return c_upper(name)

    def resolve(self):
        """Compute enum_name and reject subsets which override the checks."""
        attr = self.attr
        if 'parent-sub-message' in attr:
            raw_name = attr['parent-sub-message'].enum_name
        else:
            if 'name-prefix' in attr:
                prefix = attr['name-prefix']
            else:
                prefix = self.attr_set.name_prefix
            raw_name = f"{prefix}{self.name}"
        self.enum_name = c_upper(raw_name)

        if self.attr_set.subset_of and self.checks != self._get_real_attr().checks:
            raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        return None

    def is_scalar(self):
        scalar_types = {'u8', 'u16', 'u32', 'u64', 's32', 's64'}
        return self.type in scalar_types

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        return self.is_recursive() and not ri.op

    def presence_type(self):
        return 'present'

    def presence_member(self, space, type_filter):
        """Return the presence struct member, None if @type_filter excludes it."""
        kind = self.presence_type()
        if kind != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if kind == 'present':
            # single bit is enough to record presence
            return f"{pfx}u32 {self.c_name}:1;"
        if kind in {'len', 'count'}:
            return f"{pfx}u32 {self.c_name};"
        return None

    def _complex_member_type(self, ri):
        return None

    def free_needs_iter(self):
        return False

    def _free_lines(self, ri, var, ref):
        needs_free = self.is_multi_val() or self.presence_type() in {'count', 'len'}
        if not needs_free:
            return []
        return [f'free({var}->{ref}{self.c_name});']

    def free(self, ri, var, ref):
        """Print the lines produced by _free_lines()."""
        for line in self._free_lines(ri, var, ref):
            ri.cw.p(line)

    def arg_member(self, ri):
        """Return the list of function argument declarations for this attr."""
        member = self._complex_member_type(ri)
        if not member:
            raise Exception(f"Struct member not implemented for class type {self.type}")

        sep = '' if member.endswith('*') else ' '
        args = [f"{member}{sep}*{self.c_name}"]
        if self.presence_type() == 'count':
            args.append(f"unsigned int n_{self.c_name}")
        return args

    def struct_member(self, ri):
        """Print the struct member declaration(s) for this attr."""
        member = self._complex_member_type(ri)
        if not member:
            # No complex type, reuse the argument declarations
            for decl in self.arg_member(ri):
                ri.cw.p(decl + ';')
            return

        multi = self.is_multi_val()
        recursive = self.is_recursive_for_op(ri)
        ptr = '*' if (multi or recursive) else ''
        sep = '' if member.endswith('*') else ' '
        ri.cw.p(f"{member}{sep}{ptr}{self.c_name};")

    def _attr_policy(self, policy):
        return f'{{ .type = {policy}, }}'

    def attr_policy(self, cw):
        """Print the policy table entry of this attr."""
        big_endian = self.attr.get('byte-order') == 'big-endian'
        if big_endian and self.type in {'u16', 'u32'}:
            policy = f'NLA_BE{self.type[1:]}'
        else:
            policy = f'NLA_{c_upper(self.type)}'

        entry = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {entry},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Print the type policy table entry of this attr."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {{ .name = "{self.name}", {typol}}},')

    def _attr_put_line(self, ri, var, line):
        """Print @line as a statement, guarded by the presence check if any."""
        kind = self.presence_type()
        if kind in {'present', 'len'}:
            ri.cw.p(f"if ({var}->_{kind}.{self.c_name})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        call = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, call)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Print the parsing branch for this attr, @first selects if / else if."""
        lines, init_lines, _ = self._attr_get(ri, var)
        # _attr_get() may hand back a bare string instead of a list
        if type(lines) is str:
            lines = [lines]
        if type(init_lines) is str:
            init_lines = [init_lines]

        cw = ri.cw
        keyword = 'if' if first else 'else if'
        cw.block_start(line=f"{keyword} (type == {self.enum_name})")

        if not self.is_multi_val():
            cw.p("if (ynl_attr_validate(yarg, attr))")
            cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'present':
                cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            cw.nl()
            for init_line in init_lines:
                cw.p(init_line)

        for get_line in lines:
            cw.p(get_line)
        cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None, var="req"):
        """Write out the static inline setter function for this attr."""
        ref = (ref or []) + [self.c_name]
        member = f"{var}->{'.'.join(ref)}"

        local_vars = ['unsigned int i;'] if self.free_needs_iter() else []

        code = []
        presence = ''
        last = len(ref) - 1
        for idx, elem in enumerate(ref):
            parent = '.'.join(ref[:idx] + [''])
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            kind = 'present'
            if idx == last and self.presence_type() != 'present':
                kind = self.presence_type()
            presence = f"{var}->{parent}_{kind}.{elem}"
            if kind == 'present':
                code.append(presence + ' = 1;')

        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        has_free = any('free(' in line for line in code)
        has_alloc = any('alloc(' in line for line in code)
        if has_free and not has_alloc:
            func_name = '__' + func_name

        args = [f'{type_name(ri, direction, deref=deref)} *{var}']
        args += self.arg_member(ri)
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code, args=args)
305
306
class TypeUnused(Type):
    """Attr type with no members; policy, put, get and setter output is empty."""

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        get_lines = ['return YNL_PARSE_CB_ERROR;']
        return get_lines, None, None

    def attr_policy(self, cw):
        """Nothing to print."""
        return None

    def attr_put(self, ri, var):
        """Nothing to print."""
        return None

    def attr_get(self, ri, var, first):
        """Nothing to print."""
        return None

    def setter(self, ri, space, direction, deref=False, ref=None, var=None):
        """No setter is generated."""
        return None
331
332
class TypePad(Type):
    """Attr type with no members; policy, put, get and setter output is empty."""

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def attr_policy(self, cw):
        """Nothing to print."""
        return None

    def attr_put(self, ri, var):
        """Nothing to print."""
        return None

    def attr_get(self, ri, var, first):
        """Nothing to print."""
        return None

    def setter(self, ri, space, direction, deref=False, ref=None, var=None):
        """No setter is generated."""
        return None
354
355
class TypeScalar(Type):
    """Scalar attribute type."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"
        else:
            self.byte_order_comment = ''

        # Classic families have some funny enums, don't bother
        # computing checks, since we only need them for kernel policies
        if not family.is_classic():
            self._init_checks()

        # Added by resolve():
        self.is_bitfield = None
        del self.is_bitfield
        self.type_name = None
        del self.type_name

    def resolve(self):
        """Run the parent resolve, then set is_bitfield and type_name."""
        self.resolve_up(super())

        attr = self.attr
        has_enum = 'enum' in attr

        if attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif has_enum:
            self.is_bitfield = self.family.consts[attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if has_enum and not self.is_bitfield:
            self.type_name = self.family.consts[attr['enum']].user_type
        elif self.is_auto_scalar:
            self.type_name = f"__{self.type[0]}64"
        else:
            self.type_name = f"__{self.type}"

    def _init_checks(self):
        """Derive min / max / range / full-range / sparse entries of checks."""
        checks = self.checks

        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if low is None and high is None:
                checks['sparse'] = True
            else:
                if 'min' not in checks and (low != 0 or self.type[0] == 's'):
                    checks['min'] = low
                checks.setdefault('max', high)

        if 'min' in checks and 'max' in checks:
            min_val = self.get_limit('min')
            max_val = self.get_limit('max')
            if min_val > max_val:
                raise Exception(f'Invalid limit for "{self.name}" min: {min_val} max: {max_val}')
            checks['range'] = True

        min_val = self.get_limit('min', 0)
        max_val = self.get_limit('max', 0)
        low = min(min_val, max_val)
        high = max(min_val, max_val)
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        if low < -32768 or high > 32767:
            checks['full-range'] = True

    def _attr_policy(self, policy):
        checks = self.checks

        if self.is_bitfield:
            enum = self.family.consts[self.attr['enum']]
            mask = enum.get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'flags-mask' in checks:
            flags = self.family.consts[checks['flags-mask']]
            mask = (1 << len(flags['entries'])) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"

        if 'full-range' in checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        if 'range' in checks:
            min_str = self.get_limit_str('min')
            max_str = self.get_limit_str('max')
            return f"NLA_POLICY_RANGE({policy}, {min_str}, {max_str})"
        if 'min' in checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        if 'max' in checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        if 'sparse' in checks:
            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        width = c_upper(self.type[1:])
        return f'.type = YNL_PT_U{width}, '

    def arg_member(self, ri):
        decl = f'{self.type_name} {self.c_name}{self.byte_order_comment}'
        return [decl]

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        get_line = f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);"
        return get_line, None, None

    def _setter_lines(self, ri, member, presence):
        assignment = f"{member} = {self.c_name};"
        return [assignment]
453
454
class TypeFlag(Type):
    """Flag attribute type, put with a NULL payload of length 0."""

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        no_lines = []
        return no_lines, None, None

    def _setter_lines(self, ri, member, presence):
        return []
470
471
class TypeString(Type):
    """String attribute type, presence is tracked as a length."""

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        parts = ['.type = YNL_PT_NUL_STR, ']
        if self.is_selector:
            parts.append('.is_selector = 1, ')
        return ''.join(parts)

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return f"NLA_POLICY_EXACT_LEN({self.get_limit_str('exact-len')})"

        fields = [f'.type = {policy}']
        if 'max-len' in self.checks:
            fields.append(f".len = {self.get_limit_str('max-len')}")
        return '{ ' + ', '.join(fields) + ', }'

    def attr_policy(self, cw):
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'

        entry = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {entry},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        get_lines = [
            f"{var}->_len.{self.c_name} = len;",
            f"{dst} = malloc(len + 1);",
            f"memcpy({dst}, ynl_attr_get_str(attr), len);",
            f"{dst}[len] = 0;",
        ]
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        return [
            f"{presence} = strlen({self.c_name});",
            f"{member} = malloc({presence} + 1);",
            f'memcpy({member}, {self.c_name}, {presence});',
            f'{member}[{presence}] = 0;',
        ]
524
525
class TypeBinary(Type):
    """Binary attribute type, presence is tracked as a length."""

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        checks = self.checks
        if len(checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        if not checks:
            return '{ .type = NLA_BINARY, }'

        # Exactly one check, map it to the matching policy macro
        macros = {
            'exact-len': 'NLA_POLICY_EXACT_LEN',
            'min-len': 'NLA_POLICY_MIN_LEN',
            'max-len': 'NLA_POLICY_MAX_LEN',
        }
        check_name = list(checks)[0]
        if check_name not in macros:
            raise Exception('Unsupported check for binary type: ' + check_name)
        return macros[check_name] + '(' + self.get_limit_str(check_name) + ')'

    def attr_put(self, ri, var):
        call = (f"ynl_attr_put(nlh, {self.enum_name}, "
                f"{var}->{self.c_name}, {var}->_len.{self.c_name})")
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        get_lines = [
            f"{var}->_len.{self.c_name} = len;",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        init_lines = ['len = ynl_attr_data_len(attr);']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        return [
            f"{presence} = len;",
            f"{member} = malloc({presence});",
            f'memcpy({member}, {self.c_name}, {presence});',
        ]
576
577
class TypeBinaryStruct(TypeBinary):
    """Binary attribute stored in a member typed as the attr's "struct"."""

    def struct_member(self, ri):
        struct_name = c_lower(self.get("struct"))
        ri.cw.p(f'struct {struct_name} *{self.c_name};')

    def _attr_get(self, ri, var):
        struct_sz = f"sizeof(struct {c_lower(self.get('struct'))})"
        dst = f"{var}->{self.c_name}"
        get_lines = [
            f"{var}->_{self.presence_type()}.{self.c_name} = len;",
            # allocate a zeroed full-size struct when the payload is shorter
            f"if (len < {struct_sz})",
            f"{dst} = calloc(1, {struct_sz});",
            "else",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        init_lines = ['len = ynl_attr_data_len(attr);']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars
593
594
class TypeBinaryScalarArray(TypeBinary):
    """Binary attribute handled as an array of "sub-type" scalars."""

    def arg_member(self, ri):
        elem_type = f"__{self.get('sub-type')}"
        return [f'{elem_type} *{self.c_name}', 'size_t count']

    def presence_type(self):
        return 'count'

    def struct_member(self, ri):
        elem_type = f"__{self.get('sub-type')}"
        ri.cw.p(f'{elem_type} *{self.c_name};')

    def attr_put(self, ri, var):
        cw = ri.cw
        count_mem = f"{var}->_{self.presence_type()}.{self.c_name}"
        elem_type = f"__{self.get('sub-type')}"

        cw.block_start(line=f"if ({count_mem})")
        # 'i' holds the payload size in bytes
        cw.p(f"i = {count_mem} * sizeof({elem_type});")
        cw.p(f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, i);")
        cw.block_end()

    def _attr_get(self, ri, var):
        count_mem = f"{var}->_count.{self.c_name}"
        elem_size = f"sizeof(__{self.get('sub-type')})"
        dst = f"{var}->{self.c_name}"
        get_lines = [
            f"{count_mem} = len / {elem_size};",
            f"len = {count_mem} * {elem_size};",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        init_lines = ['len = ynl_attr_data_len(attr);']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        elem_size = f"sizeof(__{self.get('sub-type')})"
        return [
            f"{presence} = count;",
            f"count *= {elem_size};",
            f"{member} = malloc(count);",
            f'memcpy({member}, {self.c_name}, count);',
        ]
627
628
class TypeBitfield32(Type):
    """Attribute type stored as a struct nla_bitfield32."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        attr = self.attr
        if 'enum' not in attr:
            raise Exception('Enum required for bitfield32 attr')
        mask = self.family.consts[attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        call = (f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, "
                "sizeof(struct nla_bitfield32))")
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        get_line = (f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), "
                    "sizeof(struct nla_bitfield32));")
        return get_line, None, None

    def _setter_lines(self, ri, member, presence):
        copy = f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"
        return [copy]
652
653
class TypeNest(Type):
    """Nested attribute type, rendered via the nested struct's helpers."""

    def is_recursive(self):
        nested_struct = self.family.pure_nested_structs[self.nested_attrs]
        return nested_struct.recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        target = f"{var}->{ref}{self.c_name}"
        free_func = f'{self.nested_render_name}_free'
        if self.is_recursive_for_op(ri):
            # member is a pointer here, only free it when set
            return [f'if ({target})', f'{free_func}({target});']
        return [f'{free_func}(&{target});']

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        at = '&'
        if self.is_recursive_for_op(ri):
            at = ''
        call = (f"{self.nested_render_name}_put(nlh, "
                f"{self.enum_name}, {at}{var}->{self.c_name})")
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        nested_struct = self.family.pure_nested_structs[self.nested_attrs]
        args = ["&parg", "attr"]
        args += [f'{var}->{sel.name}' for sel in nested_struct.external_selectors()]

        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        get_lines = [
            f"if ({self.nested_render_name}_parse({', '.join(args)}))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None, var="req"):
        """Generate the setters of all non-recursive members of the nest."""
        ref = (ref or []) + [self.c_name]

        nested_struct = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested_struct.member_list():
            if not member.is_recursive():
                member.setter(ri, self.nested_attrs, direction, deref=deref,
                              ref=ref, var=var)
700
701
class TypeMultiAttr(Type):
    """
    Multi-value attribute type, presence is tracked as a count.

    Policy and type policy are delegated to base_type.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        attr = self.attr
        if 'type' not in attr:
            return self.nested_struct_type

        kind = attr['type']
        if kind == 'nest':
            return self.nested_struct_type
        if kind == 'binary' and 'struct' in attr:
            return None  # use arg_member()
        if kind == 'string':
            return 'struct ynl_string *'
        if kind in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            name = (self.type[0] + '64') if self.is_auto_scalar else kind
            return scalar_pfx + name
        raise Exception(f"Sub-type {kind} not supported yet")

    def arg_member(self, ri):
        if self.type != 'binary' or 'struct' not in self.attr:
            return super().arg_member(ri)
        return [f'struct {c_lower(self.attr["struct"])} *{self.c_name}',
                f'unsigned int n_{self.c_name}']

    def free_needs_iter(self):
        kind = self.attr['type']
        return kind in {'nest', 'string'}

    def _free_lines(self, ri, var, ref):
        kind = self.attr['type']
        array = f"{var}->{ref}{self.c_name}"

        if kind in scalars or kind == 'binary':
            return [f"free({array});"]

        # Elements need to be released one by one before the array itself
        if kind == 'string':
            free_elem = f"free({array}[i]);"
        elif kind == 'nest':
            free_elem = f'{self.nested_render_name}_free(&{array}[i]);'
        else:
            raise Exception(f"Free of MultiAttr sub-type {kind} not supported yet")
        return [
            f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)",
            free_elem,
            f"free({array});",
        ]

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        bump_count = f'n_{self.c_name}++;'
        return bump_count, None, None

    def attr_put(self, ri, var):
        """Print a loop putting every element of the array."""
        kind = self.attr['type']
        loop = f"for (i = 0; i < {var}->_count.{self.c_name}; i++)"
        elem = f"{var}->{self.c_name}[i]"

        if kind in scalars:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {elem});")
        elif kind == 'binary' and 'struct' in self.attr:
            struct_sz = f"sizeof(struct {c_lower(self.attr['struct'])})"
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, &{elem}, {struct_sz});")
        elif kind == 'string':
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_str(nlh, {self.enum_name}, {elem}->str);")
        elif kind == 'nest':
            ri.cw.p(loop)
            self._attr_put_line(ri, var,
                                f"{self.nested_render_name}_put(nlh, {self.enum_name}, &{elem})")
        else:
            raise Exception(f"Put of MultiAttr sub-type {kind} not supported yet")

    def _setter_lines(self, ri, member, presence):
        set_ptr = f"{member} = {self.c_name};"
        set_count = f"{presence} = n_{self.c_name};"
        return [set_ptr, set_count]
792
793
class TypeIndexedArray(Type):
    """
    Attribute type for 'indexed-array' attrs (see AttrSet.new_attr()).

    Members are rendered as an element array plus an element count; the
    element kind is selected by the attribute's sub-type.
    """

    def is_multi_val(self):
        """Indexed arrays always carry multiple values."""
        return True

    def presence_type(self):
        """Presence is tracked as an element count."""
        return 'count'

    def _complex_member_type(self, ri):
        """Return the C element type, or None when arg_member() spells it out."""
        if 'sub-type' not in self.attr:
            return self.nested_struct_type

        sub_type = self.attr['sub-type']
        if sub_type == 'nest':
            return self.nested_struct_type
        if sub_type in scalars:
            # user space gets the '__' prefixed scalar type name
            prefix = '__' if ri.ku_space == 'user' else ''
            return prefix + sub_type
        if sub_type == 'binary' and 'exact-len' in self.checks:
            return None  # use arg_member()
        raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")

    def arg_member(self, ri):
        """Return the setter argument declarations for this attribute."""
        is_fixed_binary = self.sub_type == 'binary' and 'exact-len' in self.checks
        if not is_fixed_binary:
            return super().arg_member(ri)

        # Fixed-size binary elements are passed as a pointer to arrays
        exact_len = self.checks["exact-len"]
        return [f'unsigned char (*{self.c_name})[{exact_len}]',
                f'unsigned int n_{self.c_name}']

    def _attr_policy(self, policy):
        """Return the kernel policy entry; nests use the nested-array policy."""
        if self.attr['sub-type'] != 'nest':
            return super()._attr_policy(policy)
        return f'NLA_POLICY_NESTED_ARRAY({self.nested_render_name}_nl_policy)'

    def _attr_typol(self):
        """Return the YNL policy-table initializer fragment for the elements."""
        sub_type = self.attr['sub-type']
        if sub_type in scalars:
            width = c_upper(self.sub_type[1:])
            return f'.type = YNL_PT_U{width}, '
        if sub_type == 'binary' and 'exact-len' in self.checks:
            return f'.type = YNL_PT_BINARY, .len = {self.checks["exact-len"]}, '
        if sub_type == 'nest':
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
        raise Exception(f"Typol for IndexedArray sub-type {self.attr['sub-type']} not supported, yet")

    def _attr_get(self, ri, var):
        """Return (get lines, init lines, local vars) for the parse code.

        The emitted C remembers the outer attribute and counts the validated
        nested entries; no init lines are needed.
        """
        count_entries = [
            f'attr_{self.c_name} = attr;',
            'ynl_attr_for_each_nested(attr2, attr) {',
            '\tif (__ynl_attr_validate(yarg, attr2, type))',
            '\t\treturn YNL_PARSE_CB_ERROR;',
            f'\tn_{self.c_name}++;',
            '}',
        ]
        return count_entries, None, ['const struct nlattr *attr2;']

    def attr_put(self, ri, var):
        """Emit the C code that writes every element inside one nest."""
        out = ri.cw
        out.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')
        loop_head = f'for (i = 0; i < {var}->_count.{self.c_name}; i++)'
        if self.sub_type in scalars:
            out.block_start(line=loop_head)
            out.p(f"ynl_attr_put_{self.sub_type}(nlh, i, {var}->{self.c_name}[i]);")
            out.block_end()
        elif self.sub_type == 'binary' and 'exact-len' in self.checks:
            out.p(loop_head)
            out.p(f"ynl_attr_put(nlh, i, {var}->{self.c_name}[i], {self.checks['exact-len']});")
        elif self.sub_type == 'nest':
            out.p(loop_head)
            out.p(f"{self.nested_render_name}_put(nlh, i, &{var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put for IndexedArray sub-type {self.attr['sub-type']} not supported, yet")
        out.p('ynl_attr_nest_end(nlh, array);')

    def _setter_lines(self, ri, member, presence):
        """Return the C statements storing the array argument and its count."""
        store_values = f"{member} = {self.c_name};"
        store_count = f"{presence} = n_{self.c_name};"
        return [store_values, store_count]

    def free_needs_iter(self):
        """Only nested elements need a per-element free loop."""
        return self.sub_type == 'nest'

    def _free_lines(self, ri, var, ref):
        """Return the C lines freeing the elements (if nests) and the array."""
        lines = []
        if self.sub_type == 'nest':
            lines.append(f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)")
            lines.append(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
        lines.append(f"free({var}->{ref}{self.c_name});")
        return lines
876
class TypeNestTypeValue(Type):
    """Attribute type for 'nest-type-value' attrs (see AttrSet.new_attr())."""

    def _complex_member_type(self, ri):
        """The member is the nested struct itself."""
        return self.nested_struct_type

    def _attr_typol(self):
        """Return the YNL policy-table initializer fragment for the nest."""
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """Return (get lines, init lines, local vars) for the parse code.

        For each 'type-value' level the emitted C steps one attribute deeper
        and records that attribute's type; the collected types are passed as
        extra arguments to the nested parse call.
        """
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        innermost = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            attr_decls = ", *attr_".join(levels)
            value_decls = ", ".join(levels)
            local_vars.append(f'const struct nlattr *attr_{attr_decls};')
            local_vars.append(f'__u32 {value_decls};')

            for level in levels:
                get_lines.append(f'attr_{level} = ynl_attr_data({innermost});')
                get_lines.append(f'{level} = ynl_attr_type(attr_{level});')
                innermost = 'attr_' + level

            extra_args = f", {value_decls}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {innermost}{extra_args});")
        return get_lines, init_lines, local_vars
905
906
class TypeSubMessage(TypeNest):
    """Nest-like attribute whose format is chosen by a selector attribute."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Tracks whether the selector lives in this attr set or outside of it
        self.selector = Selector(attr, attr_set)

    def _attr_typol(self):
        """Return the YNL policy-table initializer fragment for the sub-message."""
        parts = [
            f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, ',
            '.is_submsg = 1, ',
        ]
        # Reverse-parsing of the policy (ynl_err_walk() in ynl.c) does not
        # support external selectors. No family uses sub-messages with external
        # selector for requests so this is fine for now.
        if not self.selector.is_external():
            selector_attr = self.attr_set[self["selector"]]
            parts.append(f'.selector_type = {selector_attr.value} ')
        return ''.join(parts)

    def _attr_get(self, ri, var):
        """Return (get lines, init lines, local vars) for the parse code.

        The emitted C fails the parse if the selector is not set, then hands
        the selector and the attribute to the nested parser.
        """
        selector_name = self['selector']
        sel = c_lower(selector_name)
        # External selectors arrive as a local, own selectors sit in the struct
        sel_var = f"_sel_{sel}" if self.selector.is_external() else f"{var}->{sel}"

        get_lines = [
            f'if (!{sel_var})',
            f'return ynl_submsg_failed(yarg, "{self.name}", "{selector_name}");',
            f"if ({self.nested_render_name}_parse(&parg, {sel_var}, attr))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        return get_lines, init_lines, None
937
938
class Selector:
    """
    Selector of a sub-message attribute.

    A selector found in the sub-message's own attribute set is internal and
    gets flagged via is_selector; otherwise it is external and its attribute
    stays None until set_attr() provides one.
    """

    def __init__(self, msg_attr, attr_set):
        self.name = msg_attr["selector"]
        self._external = self.name not in attr_set

        if self._external:
            # The selector will need to get passed down thru the structs
            self.attr = None
            return

        self.attr = attr_set[self.name]
        self.attr.is_selector = True

    def set_attr(self, attr):
        """Record the attribute which provides the selector value."""
        self.attr = attr

    def is_external(self):
        """Return True if the selector was not found in the attr set."""
        return self._external
957
958
class Struct:
    """
    C struct description for one attribute space.

    With a type_list only the listed attributes become members; without one
    the struct is treated as nested and holds every attribute of the space.
    """

    def __init__(self, family, space_name, type_list=None, fixed_header=None,
                 inherited=None, submsg=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        self.submsg = submsg

        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.fixed_header = None
        if fixed_header:
            self.fixed_header = 'struct ' + c_lower(fixed_header)

        self.nested = type_list is None

        # The space named like the family renders without the space suffix
        if family.name == c_lower(space_name):
            self.render_name = c_lower(family.ident_name)
        else:
            self.render_name = c_lower(family.ident_name + '-' + space_name)

        self.struct_name = 'struct ' + self.render_name
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'

        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        self.in_multi_val = False  # used by a MultiAttr or legacy arrays

        member_names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in member_names]

        # Index the members by name and remember the one with the highest value
        self.attrs = {}
        self.attr_max_val = None
        highest = 0
        for name, attr in self.attr_list:
            self.attrs[name] = attr
            if attr.value >= highest:
                highest = attr.value
                self.attr_max_val = attr

    def __iter__(self):
        yield from self.attrs

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the list of (name, attr) member pairs."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Check new_inherited matches the constructor's set and publish it."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(member) for member in sorted(self._inherited)]

    def external_selectors(self):
        """Return the external selectors of the sub-message members."""
        return [attr.selector for _, attr in self.attr_list
                if isinstance(attr, TypeSubMessage) and attr.selector.is_external()]

    def free_needs_iter(self):
        """Return True if freeing any member needs an iterator."""
        return any(attr.free_needs_iter() for _, attr in self.attr_list)
1033
1034
class EnumEntry(SpecEnumEntry):
    """Enum entry which also knows whether its value must be spelled out."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # The value "changes" when it does not simply continue counting from
        # the previous entry (or from 0 for the first one); flags always do.
        expected = prev.value + 1 if prev else 0
        self.value_change = self.value != expected or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        del self.c_name

    def resolve(self):
        """Resolve the spec entry and compute the C name of the value."""
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
1053
1054
class EnumSet(SpecEnumSet):
    """Enum definition with the names needed to render it in C."""

    def __init__(self, family, yaml):
        # The naming attributes are set before the parent constructor runs
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            # An empty 'enum-name' leaves the enum without a type name
            self.enum_name = None

        self.user_type = self.enum_name if self.enum_name else 'int'

        default_pfx = f"{family.ident_name}-{yaml['name']}-"
        self.value_pfx = yaml.get('name-prefix', default_pfx)
        self.header = yaml.get('header')
        self.enum_cnt_name = yaml.get('enum-cnt-name')

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        """Create the EnumEntry for one entry of this set."""
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (low, high) if the values are contiguous, else (None, None)."""
        values = [entry.value for entry in self.entries.values()]
        low = min(values)
        high = max(values)

        is_contiguous = high - low + 1 == len(self.entries)
        if not is_contiguous:
            return None, None
        return low, high
1090
1091
class AttrSet(SpecAttrSet):
    """Attribute set with C naming and typed attribute construction."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets share the naming of the set they are carved out of
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)

            max_name = self.yaml.get('attr-max-name', f"{self.name_prefix}max")
            cnt_name = self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max")
            self.max_name = c_upper(max_name)
            self.cnt_name = c_upper(cnt_name)

        # Added by resolve:
        self.c_name = None
        del self.c_name

    def resolve(self):
        """Compute the C name of the set."""
        c_name = c_lower(self.name)
        if c_name in _C_KW:
            # Dodge C keywords with a trailing underscore
            c_name += '_'
        if c_name == self.family.c_name:
            # The set named like the family gets an empty name
            c_name = ''
        self.c_name = c_name

    def new_attr(self, elem, value):
        """Create the typed attribute object matching the spec element.

        Raises an Exception for attribute types or indexed-array sub-types
        which have no typed class.
        """
        kind = elem['type']

        if kind in scalars:
            attr_cls = TypeScalar
        elif kind == 'unused':
            attr_cls = TypeUnused
        elif kind == 'pad':
            attr_cls = TypePad
        elif kind == 'flag':
            attr_cls = TypeFlag
        elif kind == 'string':
            attr_cls = TypeString
        elif kind == 'binary':
            if 'struct' in elem:
                attr_cls = TypeBinaryStruct
            elif elem.get('sub-type') in scalars:
                attr_cls = TypeBinaryScalarArray
            else:
                attr_cls = TypeBinary
        elif kind == 'bitfield32':
            attr_cls = TypeBitfield32
        elif kind == 'nest':
            attr_cls = TypeNest
        elif kind == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['binary', 'nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            attr_cls = TypeIndexedArray
        elif kind == 'nest-type-value':
            attr_cls = TypeNestTypeValue
        elif kind == 'sub-message':
            attr_cls = TypeSubMessage
        else:
            raise Exception(f"No typed class for type {elem['type']}")

        t = attr_cls(self.family, self, elem, value)

        # Multi-attrs wrap the base type
        if 'multi-attr' in elem and elem['multi-attr']:
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
1160
1161
class Operation(SpecOperation):
    """Operation with the names and flags used when rendering it."""

    def __init__(self, family, yaml, req_value, rsp_value):
        # Fill in missing operation properties (for fixed hdr-only msgs)
        for mode in ('do', 'dump', 'event'):
            for direction in ('request', 'reply'):
                try:
                    message = yaml[mode][direction]
                except KeyError:
                    continue
                message.setdefault('attributes', [])

        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # True when both 'do' and 'dump' take a request
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        del self.enum_name

    def resolve(self):
        """Resolve the spec operation and compute its enum name."""
        self.resolve_up(super())

        if self.is_async:
            prefix = self.family.async_op_prefix
        else:
            prefix = self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Note that a notification refers to this operation."""
        self.has_ntf = True
1195
1196
class SubMessage(SpecSubMessage):
    """Sub-message definition with its C render name."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        name_parts = (family.ident_name, yaml['name'])
        self.render_name = c_lower('-'.join(name_parts))

    def resolve(self):
        """Resolve the underlying spec sub-message."""
        self.resolve_up(super())
1205
1206
class Family(SpecFamily):
    """
    Spec family carrying the extra naming and struct tables built by this file.

    Most derived state (C name, op prefixes, hooks, root sets and nested
    structs) is filled in by resolve() rather than directly by the constructor.
    """
    def __init__(self, file_name, exclude_ops):
        """
        Load the spec and derive the family key and uAPI header names.

        file_name and exclude_ops are forwarded to the SpecFamily constructor.
        """
        # Added by resolve:
        # (each name is assigned and deleted again right away, so none of
        # these attributes exists on the instance after this block)
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            # Strip the leading "linux/" (6 chars) and the trailing ".h" (2 chars)
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """
        Compute the C names and prefixes, then build the codegen tables.

        The _load_*() helpers run in a fixed order; the global policy is only
        loaded when 'kernel-policy' is 'global' (the default is 'split').
        """
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            # Without a dedicated prefix async ops share the regular one
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] holds a 'set' for membership tests and
        # a 'list' of the same names (filled by _load_hooks())
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_selector_passing()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        """Return the EnumSet built from the enum definition elem."""
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        """Return the AttrSet built from the attribute set definition elem."""
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        """Return the Operation built from elem and its request / reply values."""
        return Operation(self, elem, req_value, rsp_value)

    def new_sub_message(self, elem):
        """Return the SubMessage built from the sub-message definition elem."""
        return SubMessage(self, elem)

    def is_classic(self):
        """Return True if the family protocol is 'netlink-raw'."""
        return self.proto == 'netlink-raw'

    def _mark_notify(self):
        """Mark every operation which a 'notify' message refers to."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """
        Fill root_sets with the attrs each message uses at the top level.

        Request attrs come from 'do' and 'dump' requests; reply attrs come
        from 'do' and 'dump' replies as well as from events. Messages which
        share an attribute set have their attrs merged.
        """
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """
        Reorder pure_nested_structs so structs follow the nests they contain.

        A struct which refers to a not-yet-seen, non-recursive nest is moved
        to the end of the dict and retried later; the number of attempts is
        bounded by the square of the number of structs.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    continue

                # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                if self.pure_nested_structs[nested].recursive:
                    continue
                if nested not in pns_key_seen:
                    # Dicts are sorted, this will make struct last
                    struct = self.pure_nested_structs.pop(name)
                    self.pure_nested_structs[name] = struct
                    finished = False
                    break
            if finished:
                pns_key_seen.add(name)
            else:
                # Dependency not emitted yet, retry this struct later
                pns_key_list.append(name)

    def _load_nested_set_nest(self, spec):
        """
        Make sure the attr set nested under spec has a Struct.

        Members inherited from the parent ('type-value' names, or 'idx' for
        indexed arrays) are recorded on the struct. Raises an Exception if
        the nested set is also used as a root set. Returns the set's name.
        """
        inherit = set()
        nested = spec['nested-attributes']
        if nested not in self.root_sets:
            if nested not in self.pure_nested_structs:
                self.pure_nested_structs[nested] = \
                    Struct(self, nested, inherited=inherit,
                           fixed_header=spec.get('fixed-header'))
        else:
            raise Exception(f'Using attr set as root and nested not supported - {nested}')

        if 'type-value' in spec:
            if nested in self.root_sets:
                raise Exception("Inheriting members to a space used as root not supported")
            inherit.update(set(spec['type-value']))
        elif spec['type'] == 'indexed-array':
            inherit.add('idx')
        # Raises if the struct was created with a different set of inherited members
        self.pure_nested_structs[nested].set_inherited(inherit)

        return nested

    def _load_nested_set_submsg(self, spec):
        """
        Create a fake attr set and Struct for the sub-message used by spec.

        Every format of the sub-message becomes one attribute: a nest when it
        has an attribute set, a binary struct when it only has a fixed header,
        a flag otherwise. Returns the sub-message name.
        """
        # Fake the struct type for the sub-message itself,
        # it's not an attr_set but codegen wants attr_sets.
        submsg = self.sub_msgs[spec["sub-message"]]
        nested = submsg.name

        attrs = []
        for name, fmt in submsg.formats.items():
            attr = {
                "name": name,
                "parent-sub-message": spec,
            }
            if 'attribute-set' in fmt:
                attr |= {
                    "type": "nest",
                    "nested-attributes": fmt['attribute-set'],
                }
                if 'fixed-header' in fmt:
                    attr |= { "fixed-header": fmt["fixed-header"] }
            elif 'fixed-header' in fmt:
                attr |= {
                    "type": "binary",
                    "struct": fmt["fixed-header"],
                }
            else:
                attr["type"] = "flag"
            attrs.append(attr)

        self.attr_sets[nested] = AttrSet(self, {
            "name": nested,
            # NOTE(review): AttrSet.__init__ in this file looks up
            # 'name-prefix', confirm that 'name-pfx' is the intended key
            "name-pfx": self.name + '-' + spec.name + '-',
            "attributes": attrs
        })

        if nested not in self.pure_nested_structs:
            self.pure_nested_structs[nested] = Struct(self, nested, submsg=submsg)

        return nested

    def _load_nested_sets(self):
        """
        Discover all nested structs and propagate their usage flags.

        Starting from the root sets every nest and sub-message is visited
        once; request / reply / multi-value use is then marked on the direct
        children of the root sets and propagated down the nesting chain.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        # Walk the nests starting from the root sets, visiting each set once
        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' in spec:
                    nested = self._load_nested_set_nest(spec)
                elif 'sub-message' in spec:
                    nested = self._load_nested_set_submsg(spec)
                else:
                    continue

                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

        # Mark the structs nested directly in the root sets
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    nested = None

                if nested:
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

                    if spec.is_multi_val():
                        child = self.pure_nested_structs.get(nested)
                        child.in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                # A struct which ended up among its own child nests is recursive
                if attr_set in struct.child_nests:
                    struct.recursive = True

                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                struct.child_nests.add(child_name)
                child = self.pure_nested_structs.get(child_name)
                if child:
                    if not child.recursive:
                        struct.child_nests.update(child.child_nests)
                    child.request |= struct.request
                    child.reply |= struct.reply
                    if spec.is_multi_val():
                        child.in_multi_val = True

        # Sort again, the recursive flags may have changed since the first sort
        self._sort_pure_types()

    def _load_attr_use(self):
        """Tell each attribute whether it is used in requests and / or replies."""
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_selector_passing(self):
        """
        Point external selectors at the attribute in the containing set.

        For every nest or sub-message member, each external selector of the
        child struct must name an attribute of the containing attr set;
        otherwise an Exception is raised.
        """
        def all_structs():
            for k, v in reversed(self.pure_nested_structs.items()):
                yield k, v
            for k, _ in self.root_sets.items():
                yield k, None  # we don't have a struct, but it must be terminal

        for attr_set, struct in all_structs():
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                child = self.pure_nested_structs.get(child_name)
                for selector in child.external_selectors():
                    if selector.name in self.attr_sets[attr_set]:
                        sel_attr = self.attr_sets[attr_set][selector.name]
                        selector.set_attr(sel_attr)
                    else:
                        raise Exception("Passing selector thru more than one layer not supported")

    def _load_global_policy(self):
        """
        Build the policy shared by all ops (kernel-policy 'global').

        global_policy lists, in attribute set order, the attrs used by any
        'do' or 'dump' request. Raises an Exception if the ops do not all
        use the same attribute set.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        # Iterate the attr set rather than global_set to keep the spec order
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect the unique 'pre' / 'post' hook names of 'do' and 'dump' ops."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    # The set de-duplicates, the list keeps first-seen order
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1567
1568
class RenderInfo:
    """
    Rendering context for one operation and mode, or for a bare attr set.

    Holds the request / reply Structs along with a few facts about the
    operation (fixed header length expression, do/dump type consistency).
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.cw = cw
        self.nl = cw.nlib
        self.family = family
        self.ku_space = ku_space
        self.op = op
        self.op_mode = op_mode

        fixed_hdr = op.fixed_header if op else None
        self.fixed_hdr_len = 'ys->family->hdr_len'
        if op and op.fixed_header and op.fixed_header != family.fixed_header:
            # Only classic families may use a header differing from the family's
            if not family.is_classic():
                raise Exception("Per-op fixed header not supported, yet")
            self.fixed_hdr_len = f"sizeof(struct {c_lower(fixed_hdr)})"

        # 'do' and 'dump' response parsing is identical
        self.type_consistent = True
        self.type_oneside = False
        if op_mode != 'do' and 'dump' in op:
            if 'do' not in op:
                self.type_oneside = True
            else:
                do_replies = 'reply' in op['do']
                dump_replies = 'reply' in op["dump"]
                if do_replies != dump_replies:
                    self.type_consistent = False
                elif do_replies and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        if op:
            self.type_name = c_lower(op.name)
            self.type_name_conflict = False
        else:
            self.type_name = c_lower(attr_set)
            self.type_name_conflict = attr_set in family.consts

        self.struct = {}
        if op_mode == 'notify':
            # Notifications take the types of 'do', or 'dump' if there is no 'do'
            op_mode = 'do' if 'do' in op else 'dump'
        if op:
            for op_dir in ('request', 'reply'):
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                else:
                    type_list = []
                self.struct[op_dir] = Struct(family, self.attr_set,
                                             fixed_header=fixed_hdr,
                                             type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set,
                                          fixed_header=fixed_hdr,
                                          type_list=op['event']['attributes'])

    def type_empty(self, key):
        """Return True if struct key has no members and the request struct has no fixed header."""
        if len(self.struct[key].attr_list) != 0:
            return False
        return self.struct['request'].fixed_header is None

    def needs_nlflags(self, direction):
        """Return whether a 'do' request of a classic family is being rendered."""
        if self.op_mode != 'do' or direction != 'request':
            return False
        return self.family.is_classic()
1636
1637
class CodeWriter:
    """
    Indentation-aware writer for the generated C code.

    Output goes to stdout, or to a temporary file which is copied over
    @out_file by close_out_file(). Lines are indented with tabs according
    to the current block depth; closing brackets are delayed by one line
    so that a following "else" can be joined onto them.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # a blank line is pending
        self._block_end = False     # a closing '}' is pending
        self._silent_block = False  # previous line opened a brace-less body
        self._ind = 0               # current block depth
        self._ifdef_block = None    # CONFIG_ symbol of the open #ifdef, if any
        if out_file is None:
            self._out = sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """
        Copy the temporary output to the destination file and close it.

        With overwrite disabled the destination is left untouched when its
        contents already match. Afterwards the writer points at stdout, so
        calling this again is a no-op.
        """
        if self._out == sys.stdout:
            return
        tmp = self._out
        try:
            # Avoid modifying the file if contents didn't change
            tmp.flush()
            unchanged = (not self._overwrite and
                         os.path.isfile(self._out_file) and
                         filecmp.cmp(tmp.name, self._out_file, shallow=False))
            if not unchanged:
                with open(self._out_file, 'w+') as out_file:
                    tmp.seek(0)
                    shutil.copyfileobj(tmp, out_file)
        finally:
            # Always release the temporary file, also on the "unchanged" path
            tmp.close()
            self._out = sys.stdout

    @classmethod
    def _is_cond(cls, line):
        """Tell whether @line starts with 'if', 'while' or 'for'."""
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """
        Write one line of code at the current indentation.

        Labels (lines ending in ':') are out-dented by one level,
        preprocessor lines (starting with '#') are written at column zero
        and the line following a brace-less if/while/for/else gets one
        extra level. @add_ind is added on top of the computed indentation.
        An empty @line produces an empty output line.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        self._silent_block |= line.strip() == 'else'
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        # Don't emit indentation-only (trailing whitespace) lines
        text = '\t' * ind + line if line else ''
        self._out.write(text + '\n')

    def nl(self):
        """Request a blank line before the next line written with p()."""
        self._nl = True

    def block_start(self, line=''):
        """Write "@line {" and increase the indentation."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """
        Decrease the indentation and close the block.

        Without @line the bracket is only written once the next line is
        known; with @line it is written immediately as "}" + @line, with a
        space in between unless @line starts with ';' or ','.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """
        Write @doc as ' *' prefixed comment lines, starting a new line
        before the text would reach 79 characters. Continuation lines get
        two extra spaces when @indent is set.
        """
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """
        Write a C function prototype.

        @qual_ret holds the qualifiers and return type, @args the list of
        parameter declarations ('void' when empty), @doc an optional
        comment put above the prototype and @suffix text to append after
        the closing parenthesis. Prototypes of 80 characters or more are
        split over multiple lines.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if not qual_ret.endswith('*'):
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif not qual_ret.endswith('*'):
            v += ' '
        v += name + '('
        # Continuation lines are aligned with the opening parenthesis
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        for arg in args[1:]:
            next_len = len(v) + len(arg)
            if v.startswith('\t'):
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += arg
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """
        Write local variable declarations, longest first, followed by a
        blank line. @local_vars is a string or a list of strings; the
        caller's list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a complete function: prototype, local variables and @body lines."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """
        Write a list of (name, value) pairs as #define lines with the
        values aligned on a tab stop. int values are written as is, str
        values are wrapped in double quotes.
        """
        longest = max((len(define[0]) for define in defines), default=0)
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """
        Write (member, value) pairs as ".member = value," initializer
        lines with the '=' signs aligned on a tab stop. An empty list
        writes nothing.
        """
        longest = max((len(x[0]) for x in members), default=0)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """
        Switch the open "#ifdef CONFIG_..." block to the one for @config.

        Closes the block which is currently open (if it differs) and opens
        the new one; a false @config only closes the current block.
        """
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1834
1835
# Attribute (sub-)types treated as plain scalars; the parsing code in this
# file reads such values with the matching ynl_attr_get_<type>() helper.
scalars = {'u8', 'u16', 'u32', 'u64', 's8', 's16', 's32', 's64', 'uint', 'sint'}

# Suffix added to C identifiers for each message direction
# (an empty direction adds nothing).
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': ''
}

# Suffix op_prefix() appends to the reply type name for each op mode.
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# Reserved words of the C language.
# NOTE(review): presumably consulted so that generated identifiers don't
# clash with C keywords - the user is outside this part of the file, confirm.
_C_KW = {
    'auto',
    'bool',
    'break',
    'case',
    'char',
    'const',
    'continue',
    'default',
    'do',
    'double',
    'else',
    'enum',
    'extern',
    'float',
    'for',
    'goto',
    'if',
    'inline',
    'int',
    'long',
    'register',
    'return',
    'short',
    'signed',
    'sizeof',
    'static',
    'struct',
    'switch',
    'typedef',
    'union',
    'unsigned',
    'void',
    'volatile',
    'while'
}
1887
1888
def rdir(direction):
    """
    Return the opposite message direction ('request' <-> 'reply');
    any other value is returned unchanged.
    """
    if direction == 'request':
        return 'reply'
    return 'request' if direction == 'reply' else direction
1895
1896
def op_prefix(ri, direction, deref=False):
    """
    Build the C identifier prefix "<family>_<type name><suffix>" for the
    type rendered by @ri.

    The suffix depends on the op mode and @direction; for non-'do' replies
    @deref selects the name of a single element rather than of the
    container.
    """
    parts = [f"_{ri.type_name}"]
    mode = ri.op_mode

    if mode == 'do':
        parts.append(f"{direction_to_suffix[direction]}")
    elif mode and direction == 'request':
        parts.append('_req')
        if not ri.type_oneside:
            parts.append('_dump')
    elif mode and ri.type_consistent:
        if deref:
            parts.append(f"{direction_to_suffix[direction]}")
        else:
            parts.append(op_mode_to_wrapper[mode])
    elif mode:
        parts.append('_rsp')
        parts.append('_dump' if deref else '_list')

    suffix = ''.join(parts)
    return f"{ri.family.c_name}{suffix}"
1920
1921
def type_name(ri, direction, deref=False):
    """
    Return the C struct type ("struct <prefix>") for a message, where
    the prefix is what op_prefix() gives for the same arguments.
    """
    return f"struct {op_prefix(ri, direction, deref=deref)}"
1924
1925
def print_prototype(ri, direction, terminate=True, doc=None):
    """
    Write the prototype of the C function executing the op of @ri.

    The function takes the socket plus, if the op mode has a request, a
    pointer to the request struct; it returns a pointer to the reply type
    if the op mode has a reply and int otherwise. @terminate appends ';'.
    """
    mode_spec = ri.op[ri.op_mode]

    func = ri.op.render_name
    if ri.op_mode == 'dump':
        func += '_dump'

    params = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        params.append(f"{type_name(ri, direction)} *" + f"{direction_to_suffix[direction][1:]}")

    if 'reply' in mode_spec:
        ret_type = f"{type_name(ri, rdir(direction))} *"
    else:
        ret_type = 'int'

    ri.cw.write_func_prot(ret_type, func, params, doc=doc,
                          suffix=';' if terminate else '')
1942
1943
def print_req_prototype(ri):
    """Write the request function prototype, commented with the op's 'doc' text."""
    print_prototype(ri, "request", doc=ri.op['doc'])
1946
1947
def print_dump_prototype(ri):
    """Write the request function prototype for @ri, without a doc comment."""
    print_prototype(ri, "request")
1950
1951
def put_typol_submsg(cw, struct):
    """
    Write the policy table and the policy nest of a sub-message.

    Every member gets a YNL_PT_SUBMSG entry indexed by its position;
    members of type 'nest' also reference their nested policy.
    """
    cw.block_start(line=f'const struct ynl_policy_attr {struct.render_name}_policy[] =')

    last_idx = -1
    for last_idx, (name, arg) in enumerate(struct.member_list()):
        if arg.type == 'nest':
            nest = f" .nest = &{arg.nested_render_name}_nest,"
        else:
            nest = ""
        cw.p('[%d] = { .type = YNL_PT_SUBMSG, .name = "%s",%s },' %
             (last_idx, name, nest))

    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
    cw.p(f'.max_attr = {last_idx},')
    cw.p(f'.table = {struct.render_name}_policy,')
    cw.block_end(line=';')
    cw.nl()
1972
1973
def put_typol_fwd(cw, struct):
    """Write the extern declaration of the struct's "<render_name>_nest" policy nest."""
    cw.p(f'extern const struct ynl_policy_nest {struct.render_name}_nest;')
1976
1977
def put_typol(cw, struct):
    """
    Write the policy table and the policy nest for @struct.

    Sub-messages are handed over to put_typol_submsg(); for everything
    else the table is sized by the attribute set's max name and each
    member writes its own entry via attr_typol().
    """
    if struct.submsg:
        put_typol_submsg(cw, struct)
        return

    max_attr = struct.attr_set.max_name
    cw.block_start(line=f'const struct ynl_policy_attr {struct.render_name}_policy[{max_attr} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
    for entry in (f'.max_attr = {max_attr},',
                  f'.table = {struct.render_name}_policy,'):
        cw.p(entry)
    cw.block_end(line=';')
    cw.nl()
1997
1998
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """
    Write the "<render_name>_str()" function which looks a value up in
    the @map_name string table, returning NULL for out of range values.

    With @enum the argument uses the enum's user type, and for 'flags'
    enums the value is first converted to a bit index with ffs().
    """
    if enum:
        params = [enum.user_type + ' ' + arg_name]
    else:
        params = [f'int {arg_name}']
    cw.write_func_prot('const char *', f'{render_name}_str', params)

    body = []
    if enum and enum.type == 'flags':
        body.append(f'{arg_name} = ffs({arg_name}) - 1;')
    body.append(f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))')
    body.append('return NULL;')
    body.append(f'return {map_name}[{arg_name}];')

    cw.block_start()
    for stmt in body:
        cw.p(stmt)
    cw.block_end()
    cw.nl()
2012
2013
def put_op_name_fwd(family, cw):
    """Write the prototype of the family's "<c_name>_op_str(int op)" function."""
    cw.write_func_prot('const char *', f'{family.c_name}_op_str', ['int op'], suffix=';')
2016
2017
def put_op_name(family, cw):
    """
    Write the table mapping reply values to operation names, followed by
    the "<family>_op_str()" lookup function.

    Only messages with a reply value get an entry, and each reply value
    is listed only for the message registered for it in rsp_by_value.
    """
    strmap = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {strmap}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue

        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue

        if op.req_value == op.rsp_value:
            cw.p(f'[{op.enum_name}] = "{op_name}",')
        else:
            cw.p(f'[{op.rsp_value}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', strmap, 'op')
2037
2038
def put_enum_to_str_fwd(family, cw, enum):
    """Write the prototype of the enum's "<render_name>_str()" function."""
    cw.write_func_prot('const char *', f'{enum.render_name}_str',
                       [enum.user_type + ' value'], suffix=';')
2042
2043
def put_enum_to_str(family, cw, enum):
    """
    Write the table mapping the values of @enum to entry names, followed
    by the "<render_name>_str()" lookup function.
    """
    strmap = f'{enum.render_name}_strmap'

    cw.block_start(line=f"static const char * const {strmap}[] =")
    for ent in enum.entries.values():
        cw.p(f'[{ent.value}] = "{ent.name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, strmap, 'value', enum=enum)
2053
2054
def put_local_vars(struct):
    """
    Return the C local variable declarations the put code for @struct
    needs: an array cursor if any member is an indexed-array and a loop
    counter if any member is tracked by count.
    """
    members = [arg for _, arg in struct.member_list()]
    # Full lists (not generators) so that every member gets inspected
    needs_array = any([arg.type == 'indexed-array' for arg in members])
    needs_count = any([arg.presence_type() == 'count' for arg in members])

    decls = []
    if needs_array:
        decls.append('struct nlattr *array;')
    if needs_count:
        decls.append('unsigned int i;')
    return decls
2067
2068
def put_req_nested_prototype(ri, struct, suffix=';'):
    """
    Write the prototype of the "<render_name>_put()" function, taking the
    message header, the attribute type and a pointer to the object.
    """
    params = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]
    func = f'{struct.render_name}_put'
    ri.cw.write_func_prot('int', func, params, suffix=suffix)
2076
2077
def put_req_nested(ri, struct):
    """
    Write the "<render_name>_put()" function serializing @struct.

    Unless @struct is a sub-message the attributes are wrapped in a nest
    of the given attribute type; a fixed header, if any, is copied in
    front of the attributes.
    """
    cw = ri.cw
    wrap_in_nest = struct.submsg is None

    decls = []
    setup = []
    if wrap_in_nest:
        decls.append('struct nlattr *nest;')
        setup.append("nest = ynl_attr_nest_start(nlh, attr_type);")
    if struct.fixed_header:
        struct_sz = f'sizeof({struct.fixed_header})'
        decls.append('void *hdr;')
        setup.append(f"hdr = ynl_nlmsg_put_extra_header(nlh, {struct_sz});")
        setup.append(f"memcpy(hdr, &obj->_hdr, {struct_sz});")
    decls += put_local_vars(struct)

    put_req_nested_prototype(ri, struct, suffix='')
    cw.block_start()
    cw.write_func_lvar(decls)

    for stmt in setup:
        cw.p(stmt)

    for _, member in struct.member_list():
        member.attr_put(ri, "obj")

    if wrap_in_nest:
        cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
2110
2111
def _multi_parse(ri, struct, init_lines, local_vars):
    """
    Write the body (opening to closing bracket) of a C function parsing
    the attributes of @struct into "dst".

    The generated code walks the attributes once, letting every member
    parse itself via attr_get(). Indexed arrays and multi-attr members
    are only counted / remembered during that walk; once the counts are
    known the arrays are allocated with calloc() and filled in by
    separate loops.

    @init_lines are C statements written right after the declarations and
    @local_vars are the C declarations to start from; both lists are
    extended in place.
    """
    if struct.fixed_header:
        local_vars += ['void *hdr;']
    # Pick the loop iterating over the attributes - nests are walked from
    # the nest attribute, messages from behind the fixed header
    if struct.nested:
        if struct.fixed_header:
            iter_line = f"ynl_attr_for_each_nested_off(attr, nested, sizeof({struct.fixed_header}))"
        else:
            iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
        if ri.op.fixed_header != ri.family.fixed_header:
            if ri.family.is_classic():
                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({struct.fixed_header}))"
            else:
                raise Exception("Per-op fixed header not supported, yet")

    indexed_arrays = set()
    multi_attrs = set()
    needs_parg = False
    var_set = set()
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] in {'binary', 'nest'}:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name} = NULL;')
                indexed_arrays.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name} = NULL;')
                indexed_arrays.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
        needs_parg |= 'sub-message' in aspec

        try:
            _, _, l_vars = aspec._attr_get(ri, '')
            var_set |= set(l_vars) if l_vars else set()
        except Exception:
            pass  # _attr_get() not implemented by simple types, ignore
    # Sets of strings iterate in an order which changes from run to run,
    # sort so that the generated declarations come out in a stable order
    local_vars += sorted(var_set)
    if indexed_arrays or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = indexed_arrays | multi_attrs

    for arg in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[arg].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if struct.fixed_header:
        if struct.nested:
            ri.cw.p('hdr = ynl_attr_data(nested);')
        elif ri.family.is_classic():
            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
        else:
            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({struct.fixed_header}));")
    # Arrays get allocated below, refuse to parse into a populated struct
    for arg in sorted(all_multi):
        aspec = struct[arg]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    # Main loop over the attributes
    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass: allocate and fill the indexed arrays
    for arg in sorted(indexed_arrays):
        aspec = struct[arg]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
            # Length is validated by typol
            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass: allocate and fill the multi-attr members, this walks
    # all the attributes again picking out the matching type
    for arg in sorted(multi_attrs):
        aspec = struct[arg]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        elif aspec.type == 'binary' and 'struct' in aspec:
            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
            ri.cw.nl()
            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
        elif aspec.type == 'string':
            ri.cw.p('unsigned int len;')
            ri.cw.nl()
            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
        else:
            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nest parsers return 0, message parsers are ynl callbacks
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
2271
2272
def parse_rsp_submsg(ri, struct):
    """
    Write the "<render_name>_parse()" function of a sub-message.

    The generated function compares the selector string "sel" against
    the name of each member and runs the parsing code of the one which
    matches, marking it present where the member tracks presence.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    var = 'dst'
    local_vars = {'const struct nlattr *attr = nested;',
                  f'{struct.ptr_name}{var} = yarg->data;',
                  'struct ynl_parse_arg parg;'}

    for _, arg in struct.member_list():
        _, _, l_vars = arg._attr_get(ri, var)
        local_vars |= set(l_vars) if l_vars else set()

    ri.cw.block_start()
    # Sets of strings iterate in an order which changes from run to run,
    # sort so that declarations of equal length come out in a stable order
    ri.cw.write_func_lvar(sorted(local_vars))
    ri.cw.p('parg.ys = yarg->ys;')
    ri.cw.nl()

    first = True
    for name, arg in struct.member_list():
        kw = 'if' if first else 'else if'
        first = False

        ri.cw.block_start(line=f'{kw} (!strcmp(sel, "{name}"))')
        get_lines, init_lines, _ = arg._attr_get(ri, var)
        for line in init_lines or []:
            ri.cw.p(line)
        for line in get_lines:
            ri.cw.p(line)
        if arg.presence_type() == 'present':
            ri.cw.p(f"{var}->_present.{arg.c_name} = 1;")
        ri.cw.block_end()
    ri.cw.p('return 0;')
    ri.cw.block_end()
    ri.cw.nl()
2307
2308
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """
    Write the prototype of the "<render_name>_parse()" function.

    The arguments are the parse state, the selector string (sub-messages
    only), the nest attribute, one string per external selector and one
    __u32 per inherited member.
    """
    params = ['struct ynl_parse_arg *yarg']
    if struct.submsg:
        params.append('const char *sel')
    params.append('const struct nlattr *nested')
    params.extend('const char *_sel_' + sel.name
                  for sel in struct.external_selectors())
    params.extend('__u32 ' + arg for arg in struct.inherited)

    func = f'{struct.render_name}_parse'
    ri.cw.write_func_prot('int', func, params, suffix=suffix)
2321
2322
def parse_rsp_nested(ri, struct):
    """
    Write the "<render_name>_parse()" function of a nested struct.

    Sub-messages are handed over to parse_rsp_submsg(); a struct without
    members gets a function which just returns 0.
    """
    if struct.submsg:
        return parse_rsp_submsg(ri, struct)

    parse_rsp_nested_prototype(ri, struct, suffix='')

    decls = ['const struct nlattr *attr;',
             f'{struct.ptr_name}dst = yarg->data;']
    setup = []

    if not struct.member_list():
        # Empty nest
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, struct, setup, decls)
2341
2342
def parse_rsp_msg(ri, deref=False):
    """
    Write the "<op_prefix>_parse()" callback parsing the reply of an op.

    Nothing is written if the op mode has no reply, unless it is an
    event. A reply without members gets a callback which just returns
    YNL_PARSE_CB_OK.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not (has_reply or ri.op_mode == 'event'):
        return

    params = ['const struct nlmsghdr *nlh',
              'struct ynl_parse_arg *yarg']
    decls = [f'{type_name(ri, "reply", deref=deref)} *dst;',
             'const struct nlattr *attr;']
    setup = ['dst = yarg->data;']

    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse', params)

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply, setup, decls)
2364
2365
def print_req(ri):
    """
    Write the C function which sends the request of the op and, if the
    op mode has a reply, allocates, parses into and returns the response.

    Without a reply the generated function returns 0 / -1; with a reply
    it returns the response or NULL, freeing the response on error.
    """
    cw = ri.cw
    direction = "request"
    request = ri.struct["request"]
    has_reply = 'reply' in ri.op[ri.op_mode]

    ret_ok = '0'
    ret_err = '-1'
    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']
    if has_reply:
        ret_ok = 'rsp'
        ret_err = 'NULL'
        local_vars.append(f'{type_name(ri, rdir(direction))} *rsp;')
    if request.fixed_header:
        local_vars.extend(['size_t hdr_len;',
                           'void *hdr;'])
    local_vars.extend(put_local_vars(request))

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(local_vars)

    # Start the message
    if ri.family.is_classic():
        cw.p(f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);")
    else:
        cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    cw.p(f"ys->req_policy = &{request.render_name}_nest;")
    cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
    if has_reply:
        cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()

    if request.fixed_header:
        cw.p("hdr_len = sizeof(req->_hdr);")
        cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        cw.nl()

    # Fill in the attributes
    for _, member in request.member_list():
        member.attr_put(ri, "req")
    cw.nl()

    # Set up response parsing and execute
    if has_reply:
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yrs.yarg.data = rsp;')
        cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        if ri.op.value is not None:
            cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
        else:
            cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
        cw.nl()
    cw.p("err = ynl_exec(ys, nlh, &yrs);")
    cw.p('if (err < 0)')
    cw.p('goto err_free;' if has_reply else 'return -1;')
    cw.nl()

    cw.p(f"return {ret_ok};")
    cw.nl()

    if has_reply:
        cw.p('err_free:')
        cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        cw.p(f"return {ret_err};")

    cw.block_end()
2436
2437
def print_dump(ri):
    """
    Write the C function which executes the dump of the op and returns
    the first element of the resulting list, or NULL (after freeing the
    list) on error.
    """
    cw = ri.cw
    direction = "request"
    request = ri.struct['request']

    print_prototype(ri, direction, terminate=False)
    cw.block_start()

    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']
    if request.fixed_header:
        local_vars.extend(['size_t hdr_len;',
                           'void *hdr;'])
    if 'request' in ri.op[ri.op_mode]:
        local_vars.extend(put_local_vars(request))
    cw.write_func_lvar(local_vars)

    # Set up the dump state
    cw.p('yds.yarg.ys = ys;')
    cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.p("yds.yarg.data = NULL;")
    cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    if ri.op.value is not None:
        cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
    else:
        cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
    cw.nl()

    # Start the message
    if ri.family.is_classic():
        cw.p(f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});")
    else:
        cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if request.fixed_header:
        cw.p("hdr_len = sizeof(req->_hdr);")
        cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        cw.nl()

    # Fill in the attributes
    if "request" in ri.op[ri.op_mode]:
        cw.p(f"ys->req_policy = &{request.render_name}_nest;")
        cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
        cw.nl()
        for _, member in request.member_list():
            member.attr_put(ri, "req")
    cw.nl()

    cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    cw.p('if (err < 0)')
    cw.p('goto free_list;')
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
2495
2496
def call_free(ri, direction, var):
    """Return the C statement calling the ``<prefix>_free()`` helper on ``var``."""
    free_func = f"{op_prefix(ri, direction)}_free"
    return f"{free_func}({var});"
2499
2500
def free_arg_name(direction):
    """
    Return the argument name used by generated free functions.

    An empty direction maps to 'obj'; otherwise the name is the direction's
    suffix with its first character dropped.
    """
    if not direction:
        return 'obj'
    suffix = direction_to_suffix[direction]
    return suffix[1:]
2505
2506
def print_alloc_wrapper(ri, direction, struct=None):
    """
    Emit a static inline ``<prefix>_alloc()`` helper built on calloc().

    When ``struct`` is given and is used in a multi-value context, the
    helper takes an element count ``n``; otherwise it takes no arguments
    and allocates a single object.
    """
    func_base = op_prefix(ri, direction)
    # Conflicting type names are disambiguated with a trailing underscore
    c_struct = f"{func_base}_" if ri.type_name_conflict else func_base

    if struct and struct.in_multi_val:
        params, count = ["unsigned int n"], "n"
    else:
        params, count = ["void"], "1"

    cw = ri.cw
    cw.write_func_prot(f'static inline struct {c_struct} *',
                       f"{func_base}_alloc", params)
    cw.block_start()
    cw.p(f'return calloc({count}, sizeof(struct {c_struct}));')
    cw.block_end()
2524
2525
def print_free_prototype(ri, direction, suffix=';'):
    """
    Emit the prototype of the ``<prefix>_free()`` function for a direction.

    The struct type gets a trailing underscore when the render info reports
    a type name conflict. ``suffix`` is passed through to the prototype
    writer.
    """
    func_base = op_prefix(ri, direction)
    c_struct = f"{func_base}_" if ri.type_name_conflict else func_base
    param = f"struct {c_struct} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{func_base}_free", [param], suffix=suffix)
2533
2534
def print_nlflags_set(ri, direction):
    """Emit a static inline setter storing netlink flags in ``req->_nlmsg_flags``."""
    prefix = op_prefix(ri, direction)
    params = [f"struct {prefix} *req", "__u16 nl_flags"]

    cw = ri.cw
    cw.write_func_prot('static inline void', f"{prefix}_set_nlflags", params)
    cw.block_start()
    cw.p('req->_nlmsg_flags = nl_flags;')
    cw.block_end()
    cw.nl()
2543
2544
def _print_type(ri, direction, struct):
    """
    Emit the C struct definition describing ``struct``.

    The struct name is built from the family C name, the render type name
    and the direction suffix. The body holds, in order: the optional
    netlink flags member, the optional fixed header, the metadata
    sub-structs (``_present``, ``_len``, ``_count``), inherited members and
    finally one member per attribute.
    """
    cw = ri.cw

    type_suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        type_suffix += '_'
    if ri.op_mode == 'dump' and not ri.type_oneside:
        type_suffix += '_dump'

    cw.block_start(line=f"struct {ri.family.c_name}{type_suffix}")

    if ri.needs_nlflags(direction):
        cw.p('__u16 _nlmsg_flags;')
        cw.nl()
    if struct.fixed_header:
        cw.p(struct.fixed_header + ' _hdr;')
        cw.nl()

    # One sub-struct per metadata kind, emitted only when at least one
    # attribute contributes a member of that kind.
    for kind in ['present', 'len', 'count']:
        meta_lines = []
        for _, attr in struct.member_list():
            member = attr.presence_member(ri.ku_space, kind)
            if member:
                meta_lines.append(member)
        if not meta_lines:
            continue
        cw.block_start(line="struct")
        for member in meta_lines:
            cw.p(member)
        cw.block_end(line=f'_{kind};')
    cw.nl()

    for inherited in struct.inherited:
        cw.p(f"__u32 {inherited};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
2583
2584
def print_type(ri, direction):
    """Emit the struct definition for the op's struct in ``direction``."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
2587
2588
def print_type_full(ri, struct):
    """
    Emit the definition of a direction-less struct plus its helpers.

    Helpers (alloc wrapper, free prototype and, for user space, attribute
    setters) are only emitted for structs used in requests in a multi-value
    context.
    """
    _print_type(ri, "", struct)

    if not (struct.request and struct.in_multi_val):
        return

    print_alloc_wrapper(ri, "", struct)
    ri.cw.nl()
    free_rsp_nested_prototype(ri)
    ri.cw.nl()

    # Name conflicts are too hard to deal with with the current code base,
    # they are very rare so don't bother printing setters in that case.
    if ri.ku_space == 'user' and not ri.type_name_conflict:
        for _, member in struct.member_list():
            member.setter(ri, ri.attr_set, "", var="obj")
    ri.cw.nl()
2604
2605
def print_type_helpers(ri, direction, deref=False):
    """
    Emit helper declarations for the type of ``direction``.

    That is the free prototype, the nlflags setter when the render info
    asks for one, and the attribute setters for user space requests.
    """
    print_free_prototype(ri, direction)
    ri.cw.nl()

    if ri.needs_nlflags(direction):
        print_nlflags_set(ri, direction)

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
2617
2618
def print_req_type_helpers(ri):
    """Emit the alloc wrapper and type helpers for a non-empty request type."""
    if not ri.type_empty("request"):
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
2624
2625
def print_rsp_type_helpers(ri):
    """Emit type helpers for the reply, if the current op mode has one."""
    mode_spec = ri.op[ri.op_mode]
    if 'reply' in mode_spec:
        print_type_helpers(ri, "reply")
2630
2631
def print_parse_prototype(ri, direction, terminate=True):
    """
    Emit the prototype of the op's ``_parse`` function.

    The '_rsp' flavour is used for the 'reply' direction and '_req' for
    anything else. ``terminate`` selects whether the prototype ends with
    a semicolon.
    """
    if direction == "reply":
        flavour = "_rsp"
    else:
        flavour = "_req"
    base = f"{ri.op.render_name}{flavour}"
    params = ['const struct nlattr **tb', f"struct {base} *req"]

    ri.cw.write_func_prot('void', f"{base}_parse", params,
                          suffix=';' if terminate else '')
2640
2641
def print_req_type(ri):
    """Emit the request struct definition unless the request type is empty."""
    if not ri.type_empty("request"):
        print_type(ri, "request")
2646
2647
def print_req_free(ri):
    """Emit the request free function when the current op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2652
2653
def print_rsp_type(ri):
    """
    Emit the reply struct definition when the op mode produces one.

    'do' and 'dump' need a 'reply' section in the op; 'event' always
    qualifies; any other mode emits nothing.
    """
    mode = ri.op_mode
    if mode in ('do', 'dump'):
        has_reply = 'reply' in ri.op[mode]
    else:
        has_reply = mode == 'event'

    if has_reply:
        print_type(ri, 'reply')
2662
2663
def print_wrapped_type(ri):
    """
    Emit the wrapper struct that embeds the reply object as ``obj``.

    Dump wrappers carry a ``next`` pointer of their own type; notification
    and event wrappers carry family/cmd identifiers, a generic ``next``
    pointer and a ``free`` callback. The free prototype follows the struct.
    """
    cw = ri.cw
    wrapper = type_name(ri, 'reply')

    cw.block_start(line=f"{wrapper}")
    if ri.op_mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        cw.p('__u16 family;')
        cw.p('__u8 cmd;')
        cw.p('struct ynl_ntf_base_type *next;')
        cw.p(f"void (*free)({wrapper} *ntf);")
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
2678
2679
2680def _free_type_members_iter(ri, struct):
2681    if struct.free_needs_iter():
2682        ri.cw.p('unsigned int i;')
2683        ri.cw.nl()
2684
2685
2686def _free_type_members(ri, var, struct, ref=''):
2687    for _, attr in struct.member_list():
2688        attr.free(ri, var, ref)
2689
2690
def _free_type(ri, direction, struct):
    """
    Emit the body of a ``_free()`` function for ``struct``.

    Members are freed first. The object itself is only passed to free()
    when a direction is given; direction-less structs are left alone.
    """
    cw = ri.cw
    arg = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
2702
2703
def free_rsp_nested_prototype(ri):
    """Emit the free function prototype for a direction-less (nested) type."""
    print_free_prototype(ri, "")
2706
2707
def free_rsp_nested(ri, struct):
    """Emit the free function for a direction-less (nested) struct."""
    _free_type(ri, "", struct)
2710
2711
def print_rsp_free(ri):
    """Emit the reply free function when the current op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2716
2717
def print_dump_type_free(ri):
    """
    Emit the free function for a dump reply list.

    The generated code walks the list until YNL_LIST_END, freeing the
    members of each element's ``obj`` and then the element itself.
    """
    cw = ri.cw
    reply = ri.struct['reply']
    elem_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{elem_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply)
    # Advance before freeing so the generated loop never reads freed memory
    cw.p('rsp = next;')
    cw.p('next = rsp->next;')
    cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
2736
2737
def print_ntf_type_free(ri):
    """Emit the free function for a notification wrapper and its ``obj``."""
    cw = ri.cw
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
2746
2747
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """
    Emit the declaration line of a kernel ``nla_policy`` array.

    With ``terminate`` set this is an extern forward declaration, which is
    skipped entirely for op policies that should be static. Without it the
    line opens the array definition (``= {``).
    """
    # Only op-level policies (ri given) can be static
    is_static = ri and policy_should_be_static(struct.family)

    if terminate:
        if is_static:
            return
        qualifier, tail = 'extern ', ';'
    else:
        qualifier = 'static ' if is_static else ''
        tail = ' = {'

    top_attr = struct.attr_max_val
    if ri:
        ident = ri.op.render_name
        if ri.op.dual_policy:
            ident += '_' + ri.op_mode
    else:
        ident = struct.render_name
    cw.p(f"{qualifier}const struct nla_policy {ident}_nl_policy[{top_attr.enum_name} + 1]{tail}")
2770
2771
def print_req_policy(cw, struct, ri=None):
    """
    Emit a complete kernel ``nla_policy`` array for ``struct``.

    For op policies the array is wrapped in the op's 'config-cond' ifdef
    block, which is closed again after the array.
    """
    if ri and ri.op:
        config_cond = ri.op.get('config-cond', None)
        cw.ifdef_block(config_cond)

    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
    cw.ifdef_block(None)
    cw.nl()
2781
2782
def kernel_can_gen_family_struct(family):
    """Tell whether the family's protocol is 'genetlink'."""
    proto = family.proto
    return proto == 'genetlink'
2785
2786
def policy_should_be_static(family):
    """
    Tell whether policies of the family are emitted as static.

    True for the 'split' kernel policy, or whenever the family struct
    itself can be generated.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2789
2790
def print_kernel_policy_ranges(family, cw):
    """
    Emit ``netlink_range_validation`` structs for full-range checks.

    Only request attributes carrying a 'full-range' check in attribute sets
    that are not subsets are considered. A header comment precedes the
    first struct emitted.
    """
    banner_pending = True
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if banner_pending:
                cw.p('/* Integer value ranges */')
                banner_pending = False

            # Types whose name starts with 'u' use the unsigned struct
            unsigned = attr.type[0] == 'u'
            sign = '' if unsigned else '_signed'
            suffix = 'ULL' if unsigned else 'LL'
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            members = [(bound, attr.get_limit_str(bound, suffix=suffix))
                       for bound in ('min', 'max') if bound in attr.checks]
            cw.write_struct_init(members)
            cw.block_end(line=';')
            cw.nl()
2818
2819
def print_kernel_policy_sparse_enum_validates(family, cw):
    """
    Emit validation callbacks for attributes using sparse enums.

    Only request attributes with an enum name and a 'sparse' check in
    attribute sets that are not subsets are considered. Each callback
    switches over the attribute value, returning 0 for every known enum
    entry and -EINVAL (with an extack message) otherwise.
    """
    banner_pending = True
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or not attr.enum_name or 'sparse' not in attr.checks:
                continue

            if banner_pending:
                cw.p('/* Sparse enums validation callbacks */')
                banner_pending = False

            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
            cw.block_start()
            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
            enum = family.consts[attr['enum']]
            # All case labels share one 'return 0', chained with fallthrough
            for idx, entry in enumerate(enum.entries.values()):
                if idx:
                    cw.p('fallthrough;')
                cw.p(f'case {entry.c_name}:')
            cw.p('return 0;')
            cw.block_end()
            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
            cw.p('return -EINVAL;')
            cw.block_end()
            cw.nl()
2856
2857
def print_kernel_op_table_fwd(family, cw, terminate):
    """
    Emit the declaration line of the kernel ops table.

    With ``terminate`` set (header use) the table is declared extern, but
    only when it is exported, i.e. when the family struct cannot be
    generated; prototypes for all hooks and op handlers follow. Without
    ``terminate`` only the opening of the table definition is emitted.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.ident_name} */")

        policy = family.kernel_policy
        struct_type = {'global': 'genl_small_ops',
                       'per-op': 'genl_ops',
                       'split': 'genl_split_ops'}[policy]

        # Only exported tables get an explicit element count
        if not exported:
            cnt = ""
        elif policy == 'split':
            # Split ops have one entry per do and per dump
            cnt = sum(('do' in op) + ('dump' in op) for op in family.ops.values())
        else:
            cnt = len(family.ops)

        qual = 'const' if exported else 'static const'
        decl = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {decl};")
        else:
            cw.block_start(line=decl + ' =')

    if not terminate:
        return

    cw.nl()
    do_args = ['const struct genl_split_ops *ops',
               'struct sk_buff *skb', 'struct genl_info *info']
    dump_args = ['struct netlink_callback *cb']
    hook_protos = (('pre', 'do', 'int', do_args),
                   ('post', 'do', 'void', do_args),
                   ('pre', 'dump', 'int', dump_args),
                   ('post', 'dump', 'int', dump_args))
    for stage, mode, ret_type, args in hook_protos:
        for hook in family.hooks[stage][mode]['list']:
            cw.write_func_prot(ret_type, c_lower(hook), list(args), suffix=';')

    cw.nl()

    handler_last_arg = (('do', 'struct genl_info *info'),
                        ('dump', 'struct netlink_callback *cb'))
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        for mode, last_arg in handler_last_arg:
            if mode not in op:
                continue
            handler = c_lower(f"{family.ident_name}-nl-{op_name}-{mode}it")
            cw.write_func_prot('int', handler,
                               ['struct sk_buff *skb', last_arg], suffix=';')
    cw.nl()
2923
2924
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side declarations for the kernel ops table."""
    print_kernel_op_table_fwd(family, cw, terminate=True)
2927
2928
def print_kernel_op_table(family, cw):
    """
    Emit the definition of the kernel ops table for the family.

    The table is opened by print_kernel_op_table_fwd(). For the 'global'
    and 'per-op' policies one entry is written per non-async op; for the
    'split' policy one entry is written per op and per mode ('do', 'dump')
    the op supports. Entries are wrapped in the op's 'config-cond' ifdef
    block.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            # A single entry carries both the doit and the dumpit handler
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): the policy is always derived from the 'do'
                # request; an op without op['do']['request'] would raise
                # KeyError here -- confirm such ops cannot occur.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Names of the table members holding the pre/post hooks, per mode
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Drop the dump-only values from 'do' entries and
                    # 'strict' from 'dump' entries
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                # Member order is pre hook, handler, post hook
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Ops with dual policies have one policy per mode
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split entry is tagged with the capability of its mode
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    # Close whatever ifdef block the last entry left open
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
3007
3008
def print_kernel_mcgrp_hdr(family, cw):
    """Emit the enum of multicast group ids, if the family has any groups."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{grp['name']},"))
    cw.block_end(';')
    cw.nl()
3019
3020
def print_kernel_mcgrp_src(family, cw):
    """
    Emit the ``genl_multicast_group`` array, if the family has any groups.

    Entries are indexed by the group id built from the family name and the
    group name.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('static const struct genl_multicast_group ' + family.c_name + '_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        index = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p('[' + index + '] = { "' + grp_name + '", },')
    cw.block_end(';')
    cw.nl()
3032
3033
def print_kernel_family_struct_hdr(family, cw):
    """
    Emit the extern declaration of the family's ``struct genl_family``.

    When the family declares a 'sock-priv' type, prototypes of its init
    and destroy callbacks are emitted as well. Nothing is written when the
    family struct cannot be generated.
    """
    if not kernel_can_gen_family_struct(family):
        return

    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
    cw.nl()

    if 'sock-priv' not in family.kernel_family:
        return
    for action in ('init', 'destroy'):
        cw.p(f'void {family.c_name}_nl_sock_priv_{action}({family.kernel_family["sock-priv"]} *priv);')
    cw.nl()
3044
3045
def print_kernel_family_struct_src(family, cw):
    """
    Emit the kernel-side definition of the family's ``struct genl_family``.

    Nothing is written when the family struct cannot be generated. When
    the family declares a 'sock-priv' type, ``void *`` trampolines for its
    init and destroy callbacks are emitted first and wired into the
    struct. Ops members are only set for the 'per-op' and 'split' kernel
    policies; multicast group members only when the family has groups.
    """
    if not kernel_can_gen_family_struct(family):
        return

    has_sock_priv = 'sock-priv' in family.kernel_family
    if has_sock_priv:
        # Generate "trampolines" to make CFI happy
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
                      ["void *priv"])
        cw.nl()
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
                      ["void *priv"])
        cw.nl()

    # Name the symbol with c_name, like the extern declaration emitted by
    # print_kernel_family_struct_hdr() and every other symbol referenced
    # below, so that declaration and definition can never diverge.
    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    if has_sock_priv:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
3081
3082
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """
    Open a uAPI enum block, choosing its name from ``obj``.

    An ``enum_name`` key in ``obj`` takes precedence; if it is present but
    empty the enum is left unnamed. Otherwise the value under ``ckey`` is
    used, prefixed with the family C name.
    """
    tag = ''
    if enum_name in obj:
        explicit = obj[enum_name]
        if explicit:
            tag = ' ' + c_lower(explicit)
    elif ckey and ckey in obj:
        tag = ' ' + family.c_name + '_' + c_lower(obj[ckey])
    cw.block_start(line='enum' + tag)
3091
3092
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """
    Emit the command enum for families with a unified message id model.

    Explicit values are only written where an op's value departs from the
    running counter. With ``separate_ntf`` set, notifications and events
    are left out. The max is rendered as an enum entry or as a #define
    depending on ``max_by_define``.
    """
    top_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    count_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    top_value = f"({count_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    expected = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        if op.value != expected:
            tail = f" = {op.value},"
            expected = op.value
        else:
            tail = ','
        cw.p(op.enum_name + tail)
        expected += 1
    cw.nl()

    if max_by_define:
        cw.p(count_name)
        cw.block_end(line=';')
        cw.p(f"#define {top_name} {top_value}")
    else:
        cw.p(count_name + ',')
        cw.p(f"{top_name} = {top_value}")
        cw.block_end(line=';')
    cw.nl()
3118
3119
def render_uapi_directional(family, cw, max_by_define):
    """
    Emit the command enums for families with a directional message id model.

    Two enums are written: user-to-kernel messages (ops with a 'do' which
    are not events), then kernel-to-user messages (replies to 'do',
    notifications and events). Reply entries get a '_REPLY' suffix.
    """
    def user_entries():
        for op in family.msgs.values():
            if 'do' in op and 'event' not in op:
                yield op.enum_name, op.value

    def kernel_entries():
        for op in family.msgs.values():
            has_reply = 'do' in op and 'reply' in op['do']
            is_async = 'notify' in op or 'event' in op
            if not (has_reply or is_async):
                continue
            label = op.enum_name
            if not is_async:
                label = f'{label}_REPLY'
            yield label, op.value

    def emit_enum(side, entries):
        max_name = f"{family.op_prefix}{side}_MAX"
        cnt_name = f"__{family.op_prefix}{side}_CNT"
        max_value = f"({cnt_name} - 1)"

        cw.block_start(line='enum')
        cw.p(c_upper(f'{family.name}_MSG_{side}_NONE = 0,'))
        expected = 0
        for label, value in entries:
            # Explicit values only where the spec departs from the counter
            if value and value != expected:
                tail = f" = {value},"
                expected = value
            else:
                tail = ','
            cw.p(label + tail)
            expected += 1
        cw.nl()

        if max_by_define:
            cw.p(cnt_name)
            cw.block_end(line=';')
            cw.p(f"#define {max_name} {max_value}")
        else:
            cw.p(cnt_name + ',')
            cw.p(f"{max_name} = {max_value}")
            cw.block_end(line=';')
        cw.nl()

    emit_enum('USER', user_entries())
    emit_enum('KERNEL', kernel_entries())
3172
3173
def render_uapi(family, cw):
    """
    Render the complete uAPI header of the family into ``cw``.

    Sections are written in this order: include guard with the family
    name/version defines, the spec's definitions (consts, enums, flags),
    one enum per attribute set, the command enums, the multicast group
    defines and finally the closing of the include guard.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive 'const' definitions are batched in `defines` and flushed
    # when a definition of another type is reached, or after the loop.
    defines = []
    for const in family['definitions']:
        # Definitions which name a header are not rendered here
        if const.get('header'):
            continue

        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            # NOTE(review): presumably redundant with the 'header' check
            # at the top of the loop -- confirm before removing.
            if enum.header:
                continue

            if enum.has_doc():
                # Entry docs call for a kdoc comment, otherwise a plain
                # comment holding just the enum's own doc is written.
                if enum.has_entry_doc():
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                # Flags get a mask entry, enums get count and max entries
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    cnt_name = enum.enum_cnt_name
                    max_name = c_upper(name_pfx + 'max')
                    if not cnt_name:
                        cnt_name = '__' + name_pfx + 'max'
                    cw.p(c_upper(cnt_name) + ',')
                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            name_pfx = const.get('name-prefix', f"{family.ident_name}-")
            # NOTE(review): 'c-define-name' is looked up on the family, not
            # on the const entry -- confirm this is intended.
            defines.append([c_upper(family.get('c-define-name',
                                               f"{name_pfx}{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    # Attribute sets, subsets do not get an enum of their own
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        # Explicit values are only written where an attribute's value
        # departs from the running counter.
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        if attr_set.items():
            cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    if family.msg_id_model == 'unified':
        render_uapi_unified(family, cw, max_by_define, separate_ntf)
    elif family.msg_id_model == 'directional':
        render_uapi_directional(family, cw, max_by_define)
    else:
        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')

    # Notifications and events get an enum of their own when the spec
    # defines an 'async-prefix'.
    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            # separate_ntf is always true here, only the second half of
            # the condition filters anything.
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
3315
3316
def _render_user_ntf_entry(ri, op):
    """
    Emit one entry of the user space notification info table.

    The entry is indexed by the op's own enum name, except for classic
    families where the request op matching the op's response value is used.
    """
    cw = ri.cw

    if ri.family.is_classic():
        index_op = ri.family.req_by_value[op.rsp_value]
    else:
        index_op = op
    cw.block_start(line=f"[{index_op.enum_name}] = ")

    cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
    cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
    cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
    cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
    cw.block_end(line=',')
3328
3329
def render_user_family(family, cw, prototype):
    """
    Emit the user space ``struct ynl_family`` of the family.

    With ``prototype`` set only the extern declaration is written.
    Otherwise the notification info table (when the family has
    notifications) is written, followed by the family struct itself.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
        for ntf_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                # Notifications are rendered using the op they refer to
                target = family.ops[ntf_op['notify']]
                ri = RenderInfo(cw, family, "user", target, "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_name)
            _render_user_ntf_entry(ri, ntf_op)
        for _, op in family.ops.items():
            if 'event' in op:
                ri = RenderInfo(cw, family, "user", op, "event")
                _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.is_classic():
        cw.p('.is_classic\t= true,')
        cw.p(f'.classic_id\t= {family.get("protonum")},')
        # Classic families only have a header when the spec fixes one
        if family.fixed_header:
            cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
    elif family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
    cw.block_end(line=';')
3371
3372
def family_contains_bitfield32(family):
    """
    Check if any attribute of @family has the bitfield32 type.

    Attribute sets which are subsets of another set are not inspected.
    """
    return any(
        attr.type == "bitfield32"
        for _, attr_set in family.attr_sets.items()
        if not attr_set.subset_of
        for _, attr in attr_set.items()
    )
3381
3382
def find_kernel_root(full_path):
    """
    Locate the root of the source tree which contains @full_path.

    Walk up from @full_path one path component at a time until a directory
    holding a MAINTAINERS file is found.

    Returns a tuple of (root directory, @full_path relative to that root).
    Raises FileNotFoundError if no parent directory contains MAINTAINERS.
    """
    sub_path = ''
    while True:
        parent = os.path.dirname(full_path)
        # Joining with the (initially empty) tail leaves a trailing
        # separator on sub_path, it gets stripped on return.
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            return parent, sub_path[:-1]
        # dirname() stops changing the path once the top of the file system
        # (or the empty relative path) is reached, without this check
        # we would spin here forever for specs outside of a source tree.
        if parent == full_path:
            raise FileNotFoundError(f'No MAINTAINERS file found in any parent directory of {sub_path[:-1]}')
        full_path = parent
3391
3392
def main():
    """
    Command line entry point of the C code generator.

    Parse the arguments, load the spec and write the requested output:
    a uAPI header, or the header / source file for kernel or user space.
    Exits with status 1 if the spec fails to parse as YAML or if it does
    not carry the required license.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True,
                        choices=('user', 'kernel', 'uapi'))
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    # --header and --source share one destination, the None check below
    # rejects command lines which carry neither of them
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
    except yaml.YAMLError as exc:
        print(exc)
        sys.exit(1)

    if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
        print('Spec license:', parsed.license)
        print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    # Banner common to all outputs: license, spec path and the arguments
    # the file was generated with
    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    if args.out_file:
        # Name of the header which accompanies the output file, swap
        # the suffix (whatever its length) for .h
        hdr_file = os.path.splitext(os.path.basename(args.out_file))[0] + ".h"
    else:
        hdr_file = "generated_header_file.h"

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = []
    for definition in parsed['definitions'] + parsed['attribute-sets']:
        if 'header' in definition:
            headers.append(definition['header'])
    if args.mode == 'user':
        headers.append(parsed.uapi_header)
    # Include each header once, in the order of first appearance
    for one in dict.fromkeys(headers):
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            if any(struct.request for struct in parsed.pure_nested_structs.values()):
                cw.p('/* Common nested types */')
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for _, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            print_kernel_policy_ranges(parsed, cw)
            print_kernel_policy_sparse_enum_validates(parsed, cw)

            if any(struct.request for struct in parsed.pure_nested_structs.values()):
                cw.p('/* Common nested types */')
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for _, op in parsed.ops.items():
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for _, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for _, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    if not ri.type_consistent or ri.type_oneside:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for _, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for _, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            has_recursive_nests = False
            cw.p('/* Policies */')
            for struct in parsed.pure_nested_structs.values():
                if struct.recursive:
                    put_typol_fwd(cw, struct)
                    has_recursive_nests = True
            if has_recursive_nests:
                cw.nl()
            for struct in parsed.pure_nested_structs.values():
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            if has_recursive_nests:
                # Prototypes for the nested type helpers are only written
                # if at least one of the nested structs is recursive
                for attr_set, struct in parsed.pure_nested_structs.items():
                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                    free_rsp_nested_prototype(ri)
                    if struct.request:
                        put_req_nested_prototype(ri, struct)
                    if struct.reply:
                        parse_rsp_nested_prototype(ri, struct)
                cw.nl()
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for _, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent or ri.type_oneside:
                        parse_rsp_msg(ri, deref=True)
                    print_req_free(ri)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for _, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
3701
3702
# Only run the generator when executed as a script, not when imported
if __name__ == "__main__":
    main()
3705