xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision 3f1c07fc21c68bd3bd2df9d2c9441f6485e934d9)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import filecmp
6import pathlib
7import os
8import re
9import shutil
10import sys
11import tempfile
12import yaml
13
14sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
15from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
16from lib import SpecSubMessage
17
18
def c_upper(name):
    """Return name upper-cased, with every '-' replaced by '_'."""
    return name.replace('-', '_').upper()
21
22
def c_lower(name):
    """Return name lower-cased, with every '-' replaced by '_'."""
    return name.replace('-', '_').lower()
25
26
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value

    The name must have the form <u|s><bit width>-<min|max>.

    Raises ValueError when the name does not have that form. Malformed
    names used to be either silently turned into a number (e.g. "x32-max"
    or "u32_max") or rejected with an unrelated int() parsing error.
    """
    match = re.fullmatch(r'([us])([0-9]+)-(min|max)', name)
    if match is None:
        raise ValueError(f'Invalid limit "{name}", expected <u|s><bits>-<min|max>, e.g. u32-max')
    sign, width, bound = match.group(1), int(match.group(2)), match.group(3)

    if sign == 'u':
        # Unsigned types span 0 .. 2^width - 1
        return 0 if bound == 'min' else (1 << width) - 1

    # Signed types span -2^(width - 1) .. 2^(width - 1) - 1
    value = (1 << (width - 1)) - 1
    if bound == 'min':
        value = -value - 1
    return value
40
41
class BaseNlLib:
    """Provider of the C expressions used to refer to library state."""

    def get_family_id(self):
        """Return the C expression which evaluates to the family id."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
45
46
class Type(SpecAttr):
    """
    Base class for rendering a single attribute into C code.

    Keeps the state common to all attribute types (C member name, checks,
    naming of the nested struct, request / reply usage) and implements the
    generic rendering entry points. The type-specific hooks (_attr_typol(),
    _attr_get(), _setter_lines(), attr_put(), ...) raise here and are meant
    to be overridden by sub-classes.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        # Flipped by set_request() / set_reply()
        self.request = False
        self.reply = False

        self.is_selector = False

        if 'len' in attr:
            self.len = attr['len']

        # 'nested-attributes' takes precedence over 'sub-message'
        if 'nested-attributes' in attr:
            nested = attr['nested-attributes']
        elif 'sub-message' in attr:
            nested = attr['sub-message']
        else:
            nested = None

        if nested:
            self.nested_attrs = nested
            # Don't repeat the name when the nest is named like the family
            if nested == family.name:
                render_name = f"{family.ident_name}"
            else:
                render_name = f"{family.ident_name}_{nested}"
            self.nested_render_name = c_lower(render_name)

            # The struct gets a trailing '_' when the nest shares its name
            # with one of the family's constants
            struct_type = 'struct ' + self.nested_render_name
            if nested in self.family.consts:
                struct_type += '_'
            self.nested_struct_type = struct_type

        # Make the member name usable in C: dodge the names listed in _C_KW
        # and don't start with a digit
        c_name = c_lower(self.name)
        if c_name in _C_KW:
            c_name += '_'
        if c_name[0].isdigit():
            c_name = '_' + c_name
        self.c_name = c_name

        # Added by resolve():
        # (set and immediately removed, so the attribute does not exist
        #  until resolve() assigns it)
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        """Return the attr of the same name from the set this set is a subset of."""
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        parent_set = self.family.attr_sets[self.attr_set.subset_of]
        return parent_set[self.name]

    def set_request(self):
        """Mark the attr (and the "real" attr for subsets) as used in requests."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr (and the "real" attr for subsets) as used in replies."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """
        Return the numerical value of check @limit, or @default if not set.

        Integers are returned as is, names of family constants resolve to
        the constant's value, anything else goes thru limit_to_number().
        """
        value = self.checks.get(limit, default)
        if value is None or isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """
        Return check @limit as a string for use in C code ('' if not set).

        @suffix is only appended to integer values. Constants which do not
        come from a header are prefixed with the family name.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts and not self.family.consts[value].get('header'):
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute enum_name; refuse subset attrs which override checks."""
        if 'parent-sub-message' in self.attr:
            enum_name = self.attr['parent-sub-message'].enum_name
        elif 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        """Falsy by default, sub-classes holding multiple values return True."""
        return None

    def is_scalar(self):
        """Check if the attr type is one of the listed integer types."""
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        """False by default, overridden by types which may nest recursively."""
        return False

    def is_recursive_for_op(self, ri):
        """Check if the attr is recursive and @ri is not rendering an op."""
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Name of the presence tracking group: 'present', 'len', 'count' or ''."""
        return 'present'

    def presence_member(self, space, type_filter):
        """
        Return the declaration of the presence member, or None.

        None is returned when the attr's presence type is not @type_filter,
        or when the attr tracks no presence at all.
        """
        kind = self.presence_type()
        if kind != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if kind == 'present':
            # Plain presence only needs a single bit
            return f"{pfx}u32 {self.c_name}:1;"
        if kind in {'len', 'count'}:
            return f"{pfx}u32 {self.c_name};"
        return None

    def _complex_member_type(self, ri):
        """C type of the struct member for complex types, None otherwise."""
        return None

    def free_needs_iter(self):
        """Check if freeing the member needs the 'i' iterator variable."""
        return False

    def _free_lines(self, ri, var, ref):
        """Return the C lines which free the memory held by the member."""
        if self.is_multi_val() or self.presence_type() in {'count', 'len'}:
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        """Emit the lines produced by _free_lines()."""
        for line in self._free_lines(ri, var, ref):
            ri.cw.p(line)

    def arg_member(self, ri):
        """
        Return the list of C argument declarations for setting the member.

        The generic version only handles complex types, the rest is expected
        to override this method.
        """
        member = self._complex_member_type(ri)
        if not member:
            raise Exception(f"Struct member not implemented for class type {self.type}")

        spc = '' if member[-1] == '*' else ' '
        args = [member + spc + '*' + self.c_name]
        if self.presence_type() == 'count':
            args += ['unsigned int n_' + self.c_name]
        return args

    def struct_member(self, ri):
        """Emit the declaration(s) of the struct member(s) holding the attr."""
        member = self._complex_member_type(ri)
        if member:
            # Multi-value and recursive members are held by pointer
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            spc = '' if member[-1] == '*' else ' '
            ri.cw.p(f"{member}{spc}{ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Return the initializer of the kernel policy entry."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit the kernel policy entry for the attr."""
        policy = f'NLA_{c_upper(self.type)}'
        # Only the 16 and 32 bit types are switched to the NLA_BE* types
        if self.attr.get('byte-order') == 'big-endian' and self.type in {'u16', 'u32'}:
            policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the type-specific part of the user space policy entry."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit the user space (YNL) policy entry for the attr."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {{ .name = "{self.name}", {typol}}},')

    def _attr_put_line(self, ri, var, line):
        """Emit @line, guarded by a presence check for 'present' and 'len'."""
        presence = self.presence_type()
        if presence in {'present', 'len'}:
            ri.cw.p(f"if ({var}->_{presence}.{self.c_name})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a put of the member using the ynl_attr_put_<put_type>() helper."""
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        """Emit the code adding the attr to a message, type-specific."""
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return (parsing lines, init lines, local variables), type-specific."""
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Emit the branch of the parsing code which handles this attr.

        @first selects between opening the chain with 'if' and continuing
        it with 'else if'. Always returns True.
        """
        lines, init_lines, _ = self._attr_get(ri, var)
        # _attr_get() may return a single line instead of a list
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'present':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        """Return the C lines storing the argument in @member, type-specific."""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None, var="req"):
        """
        Emit the static inline setter function for the member.

        @ref is the list of names of the nests leading to the member,
        presence of each of them gets set by the setter as well.
        """
        ref = (ref if ref else []) + [self.c_name]
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        last = len(ref) - 1
        for i, name in enumerate(ref):
            path = f"{var}->{'.'.join(ref[:i] + [''])}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == last and self.presence_type() != 'present':
                presence = f"{path}_{self.presence_type()}.{name}"
                continue
            presence = f"{path}_present.{name}"
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        # Setters which free the old value but don't allocate a new one
        # get a '__' name prefix
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
305
306
class TypeUnused(Type):
    """
    Attr which renders no struct member, policy, put, get or setter code.

    The user space policy marks it with YNL_PT_REJECT.
    """

    def presence_type(self):
        """No presence is tracked."""
        return ''

    def arg_member(self, ri):
        """There are no setter arguments."""
        return []

    def _attr_get(self, ri, var):
        """Parsing lines which fail the parse, no init lines or locals."""
        get_lines = ['return YNL_PARSE_CB_ERROR;']
        return get_lines, None, None

    def _attr_typol(self):
        """User space policy type."""
        return '.type = YNL_PT_REJECT, '

    def attr_policy(self, cw):
        """Emit nothing."""

    def attr_put(self, ri, var):
        """Emit nothing."""

    def attr_get(self, ri, var, first):
        """Emit nothing, returns None."""

    def setter(self, ri, space, direction, deref=False, ref=None, var=None):
        """Emit nothing."""
331
332
class TypePad(Type):
    """
    Attr which renders no struct member, policy, put, get or setter code.

    The user space policy marks it with YNL_PT_IGNORE.
    """

    def presence_type(self):
        """No presence is tracked."""
        return ''

    def arg_member(self, ri):
        """There are no setter arguments."""
        return []

    def _attr_typol(self):
        """User space policy type."""
        return '.type = YNL_PT_IGNORE, '

    def attr_put(self, ri, var):
        """Emit nothing."""

    def attr_get(self, ri, var, first):
        """Emit nothing, returns None."""

    def attr_policy(self, cw):
        """Emit nothing."""

    def setter(self, ri, space, direction, deref=False, ref=None, var=None):
        """Emit nothing."""
354
355
class TypeScalar(Type):
    """Rendering of scalar attrs, including the value checks of their policies."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Comment appended to the member declaration by arg_member()
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"
        else:
            self.byte_order_comment = ''

        # Classic families have some funny enums, don't bother
        # computing checks, since we only need them for kernel policies
        if not family.is_classic():
            self._init_checks()

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Resolve the base class state, then compute is_bitfield and type_name."""
        self.resolve_up(super())

        has_enum = 'enum' in self.attr

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif has_enum:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if has_enum and not self.is_bitfield:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            # Auto-sized scalars are held in the 64 bit type of the same sign
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _init_checks(self):
        """
        Derive the implied checks from the enum and the explicit limits.

        May add 'sparse', 'min', 'max', 'range' and 'full-range' to checks.
        Raises if the limits are inverted or negative for an unsigned type.
        """
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if low is None and high is None:
                self.checks['sparse'] = True
            else:
                # Explicit limits win over the ones implied by the enum
                if 'min' not in self.checks and (low != 0 or self.type[0] == 's'):
                    self.checks['min'] = low
                self.checks.setdefault('max', high)

        if 'min' in self.checks and 'max' in self.checks:
            min_limit = self.get_limit('min')
            max_limit = self.get_limit('max')
            if min_limit > max_limit:
                raise Exception(f'Invalid limit for "{self.name}" min: {min_limit} max: {max_limit}')
            self.checks['range'] = True

        bounds = (self.get_limit('min', 0), self.get_limit('max', 0))
        low, high = min(bounds), max(bounds)
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside of the 16 bit signed range need the full-range policy
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

    def _attr_policy(self, policy):
        """Return the kernel policy initializer matching the attr's checks."""
        checks = self.checks

        if 'flags-mask' in checks or self.is_bitfield:
            if self.is_bitfield:
                mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
            else:
                # One bit per entry of the referenced constant
                flag_cnt = len(self.family.consts[checks['flags-mask']]['entries'])
                mask = (1 << flag_cnt) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'full-range' in checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        if 'range' in checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        if 'min' in checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        if 'max' in checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        if 'sparse' in checks:
            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        """User space policy type, built from the type name without its first char."""
        width = c_upper(self.type[1:])
        return f'.type = YNL_PT_U{width}, '

    def arg_member(self, ri):
        """The value is passed to setters (and stored) by value."""
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        """Emit a put using the helper named after the attr type."""
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        """Single parsing line reading the value with the helper for the type."""
        get_line = f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);"
        return get_line, None, None

    def _setter_lines(self, ri, member, presence):
        """Plain assignment of the argument to the member."""
        return [f"{member} = {self.c_name};"]
453
454
class TypeFlag(Type):
    """Rendering of flag attrs, which carry no payload."""

    def arg_member(self, ri):
        """Setters take no value argument."""
        return []

    def _attr_typol(self):
        """User space policy type."""
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        """Emit a put of an attr with a NULL payload of length 0."""
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        """Nothing to parse, no init lines or locals."""
        return [], None, None

    def _setter_lines(self, ri, member, presence):
        """Nothing to store."""
        return []
470
471
class TypeString(Type):
    """Rendering of string attrs, held as a char pointer plus a _len entry."""

    def arg_member(self, ri):
        """Setters take the string as a const char pointer."""
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        """Presence is tracked in the _len group."""
        return 'len'

    def struct_member(self, ri):
        """Emit the char pointer member."""
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        """User space policy type, flagged when the attr is a selector."""
        fields = ['.type = YNL_PT_NUL_STR, ']
        if self.is_selector:
            fields.append('.is_selector = 1, ')
        return ''.join(fields)

    def _attr_policy(self, policy):
        """Kernel policy initializer, honoring exact-len and max-len checks."""
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'

        fields = ['.type = ' + policy]
        if 'max-len' in self.checks:
            fields.append('.len = ' + self.get_limit_str('max-len'))
        return '{ ' + ', '.join(fields) + ', }'

    def attr_policy(self, cw):
        """Emit the kernel policy entry, NUL-terminated unless unterminated-ok."""
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def attr_put(self, ri, var):
        """Emit a put using the string helper."""
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        """Parsing lines which copy the string into a freshly allocated buffer."""
        len_mem = var + '->_len.' + self.c_name
        dst = f"{var}->{self.c_name}"

        get_lines = [
            f"{len_mem} = len;",
            f"{dst} = malloc(len + 1);",
            f"memcpy({dst}, ynl_attr_get_str(attr), len);",
            f"{dst}[len] = 0;",
        ]
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        """Setter lines which store a NUL-terminated copy of the argument."""
        src = self.c_name
        return [
            f"{presence} = strlen({src});",
            f"{member} = malloc({presence} + 1);",
            f'memcpy({member}, {src}, {presence});',
            f'{member}[{presence}] = 0;',
        ]
524
525
class TypeBinary(Type):
    """Rendering of binary attrs, held as a void pointer plus a _len entry."""

    def arg_member(self, ri):
        """Setters take a pointer to the data and its length."""
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        """Presence is tracked in the _len group."""
        return 'len'

    def struct_member(self, ri):
        """Emit the void pointer member."""
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        """User space policy type."""
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        """
        Kernel policy initializer, supports at most one length check.

        Raises for more than one check, or a check other than exact-len,
        min-len and max-len.
        """
        if not self.checks:
            return '{ .type = NLA_BINARY, }'
        if len(self.checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')

        check_name = next(iter(self.checks))
        macros = {
            'exact-len': 'NLA_POLICY_EXACT_LEN(',
            'min-len': 'NLA_POLICY_MIN_LEN(',
            'max-len': 'NLA_POLICY_MAX_LEN(',
        }
        if check_name not in macros:
            raise Exception('Unsupported check for binary type: ' + check_name)
        return macros[check_name] + self.get_limit_str(check_name) + ')'

    def attr_put(self, ri, var):
        """Emit a put of the data, using the length from the _len group."""
        put_call = (f"ynl_attr_put(nlh, {self.enum_name}, "
                    f"{var}->{self.c_name}, {var}->_len.{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        """Parsing lines which copy the payload into a freshly allocated buffer."""
        len_mem = var + '->_len.' + self.c_name
        dst = f"{var}->{self.c_name}"

        get_lines = [
            f"{len_mem} = len;",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        init_lines = ['len = ynl_attr_data_len(attr);']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        """Setter lines which store a copy of the data and its length."""
        return [
            f"{presence} = len;",
            f"{member} = malloc({presence});",
            f'memcpy({member}, {self.c_name}, {presence});',
        ]
576
577
class TypeBinaryStruct(TypeBinary):
    """Binary attr whose payload is the C struct named by the 'struct' property."""

    def struct_member(self, ri):
        """Emit the member as a pointer to the struct."""
        struct_name = c_lower(self.get("struct"))
        ri.cw.p(f'struct {struct_name} *{self.c_name};')

    def _attr_get(self, ri, var):
        """
        Parsing lines which copy the payload into a freshly allocated buffer.

        Payloads shorter than the struct get a zeroed buffer of the full
        struct size.
        """
        struct_sz = 'sizeof(struct ' + c_lower(self.get("struct")) + ')'
        len_mem = var + '->_' + self.presence_type() + '.' + self.c_name
        dst = f"{var}->{self.c_name}"

        get_lines = [
            f"{len_mem} = len;",
            f"if (len < {struct_sz})",
            f"{dst} = calloc(1, {struct_sz});",
            "else",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        init_lines = ['len = ynl_attr_data_len(attr);']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars
593
594
class TypeBinaryScalarArray(TypeBinary):
    """Binary attr holding an array of scalars of the type named by 'sub-type'."""

    def arg_member(self, ri):
        """Setters take a pointer to the elements and the element count."""
        return [f'__{self.get("sub-type")} *{self.c_name}', 'size_t count']

    def presence_type(self):
        """The number of elements is tracked in the _count group."""
        return 'count'

    def struct_member(self, ri):
        """Emit the member as a pointer to the element type."""
        ri.cw.p(f'__{self.get("sub-type")} *{self.c_name};')

    def attr_put(self, ri, var):
        """Emit a put of the whole array, if the element count is non-zero."""
        count_mem = f"{var}->_{self.presence_type()}.{self.c_name}"
        elem_type = f"__{self.get('sub-type')}"

        ri.cw.block_start(line=f"if ({count_mem})")
        # The generated code uses 'i' to hold the size of the array in bytes
        ri.cw.p(f"i = {count_mem} * sizeof({elem_type});")
        ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, i);")
        ri.cw.block_end()

    def _attr_get(self, ri, var):
        """Parsing lines which copy the elements into a freshly allocated array."""
        len_mem = var + '->_count.' + self.c_name
        elem_type = f"__{self.get('sub-type')}"
        dst = f"{var}->{self.c_name}"

        get_lines = [
            f"{len_mem} = len / sizeof({elem_type});",
            # Recompute the length from the count of whole elements
            f"len = {len_mem} * sizeof({elem_type});",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        init_lines = ['len = ynl_attr_data_len(attr);']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        """Setter lines which store the count and a copy of the elements."""
        elem_type = f"__{self.get('sub-type')}"
        return [
            f"{presence} = count;",
            f"count *= sizeof({elem_type});",
            f"{member} = malloc(count);",
            f'memcpy({member}, {self.c_name}, count);',
        ]
627
628
class TypeBitfield32(Type):
    """Rendering of bitfield32 attrs, held as struct nla_bitfield32."""

    def _complex_member_type(self, ri):
        """C type of the struct member."""
        return "struct nla_bitfield32"

    def _attr_typol(self):
        """User space policy type."""
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        """Kernel policy initializer built from the mask of the attr's enum."""
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        """Emit a put of the whole struct."""
        put_call = (f"ynl_attr_put(nlh, {self.enum_name}, "
                    f"&{var}->{self.c_name}, sizeof(struct nla_bitfield32))")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        """Single parsing line copying the payload into the member."""
        get_line = f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));"
        return get_line, None, None

    def _setter_lines(self, ri, member, presence):
        """Setter line copying the struct pointed to by the argument."""
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
652
653
class TypeNest(Type):
    """Rendering of nested attrs, held as the struct of the nested attr set."""

    def is_recursive(self):
        """Use the 'recursive' marking of the nested struct."""
        nested = self.family.pure_nested_structs[self.nested_attrs]
        return nested.recursive

    def _complex_member_type(self, ri):
        """C type of the struct member."""
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        """
        Return the lines which free the nested struct.

        In the recursive case the member itself is passed to the free
        function, behind a NULL check. Otherwise its address is passed.
        """
        member = f'{var}->{ref}{self.c_name}'
        if self.is_recursive_for_op(ri):
            return [f'if ({member})',
                    f'{self.nested_render_name}_free({member});']
        return [f'{self.nested_render_name}_free(&{member});']

    def _attr_typol(self):
        """User space policy type, pointing at the nested policy."""
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        """Kernel policy initializer, pointing at the nested policy."""
        return 'NLA_POLICY_NESTED(' + self.nested_render_name + '_nl_policy)'

    def attr_put(self, ri, var):
        """Emit a call to the put function of the nested struct."""
        at = '' if self.is_recursive_for_op(ri) else '&'
        put_call = (f"{self.nested_render_name}_put(nlh, "
                    f"{self.enum_name}, {at}{var}->{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        """Parsing lines which call the parse function of the nested struct."""
        pns = self.family.pure_nested_structs[self.nested_attrs]
        # External selectors are passed to the parser as extra arguments
        args = ["&parg", "attr"] + [f'{var}->{sel.name}' for sel in pns.external_selectors()]

        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        get_lines = [
            f"if ({self.nested_render_name}_parse({', '.join(args)}))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None, var="req"):
        """Emit setters for all non-recursive members of the nested struct."""
        path = (ref or []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]

        for _, member in nested.member_list():
            if not member.is_recursive():
                member.setter(ri, self.nested_attrs, direction, deref=deref,
                              ref=path, var=var)
700
701
class TypeMultiAttr(Type):
    """
    Rendering of attrs which may appear multiple times in a message.

    Values are held in an array, the number of entries is tracked in the
    _count group. Policies are delegated to @base_type.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        # Type object used for rendering the policies
        self.base_type = base_type

    def is_multi_val(self):
        """The attr holds multiple values."""
        return True

    def presence_type(self):
        """The number of entries is tracked in the _count group."""
        return 'count'

    def _complex_member_type(self, ri):
        """C type of a single array entry, None if arg_member() handles it."""
        attr = self.attr
        if 'type' not in attr or attr['type'] == 'nest':
            return self.nested_struct_type

        kind = attr['type']
        if kind == 'binary' and 'struct' in attr:
            return None  # use arg_member()
        if kind == 'string':
            return 'struct ynl_string *'
        if kind in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            # Auto-sized scalars are held in the 64 bit type of the same sign
            name = self.type[0] + '64' if self.is_auto_scalar else kind
            return scalar_pfx + name
        raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def arg_member(self, ri):
        """Setter arguments: pointer to the array and the number of entries."""
        if self.type != 'binary' or 'struct' not in self.attr:
            return super().arg_member(ri)
        return [f'struct {c_lower(self.attr["struct"])} *{self.c_name}',
                f'unsigned int n_{self.c_name}']

    def free_needs_iter(self):
        """Nests and strings are freed entry by entry."""
        return self.attr['type'] in {'nest', 'string'}

    def _free_lines(self, ri, var, ref):
        """Return the lines which free the entries (if needed) and the array."""
        kind = self.attr['type']
        array = f"{var}->{ref}{self.c_name}"
        loop = f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)"

        if kind in scalars or kind == 'binary':
            return [f"free({array});"]
        if kind == 'string':
            return [loop,
                    f"free({array}[i]);",
                    f"free({array});"]
        if kind == 'nest':
            return [loop,
                    f'{self.nested_render_name}_free(&{array}[i]);',
                    f"free({array});"]
        raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        """Kernel policy initializer of the base type."""
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        """User space policy type of the base type."""
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        """Single parsing line which only counts the occurrence."""
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        """Emit a loop which puts every entry of the array as a separate attr."""
        kind = self.attr['type']
        loop = f"for (i = 0; i < {var}->_count.{self.c_name}; i++)"
        entry = f"{var}->{self.c_name}[i]"

        if kind in scalars:
            put_type = self.type
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {entry});")
        elif kind == 'binary' and 'struct' in self.attr:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, &{entry}, sizeof(struct {c_lower(self.attr['struct'])}));")
        elif kind == 'string':
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_str(nlh, {self.enum_name}, {entry}->str);")
        elif kind == 'nest':
            ri.cw.p(loop)
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, {self.enum_name}, &{entry})")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        """Setter lines which store the array pointer and the entry count."""
        name = self.c_name
        return [f"{member} = {name};",
                f"{presence} = n_{name};"]
792
793
class TypeIndexedArray(Type):
    """Attribute type used for 'indexed-array' attributes.

    The methods handle three element sub-types: 'nest', the scalar types and
    'binary' with an 'exact-len' check. Other sub-types raise.
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        """Return the C element type, or None when arg_member() supplies it."""
        if 'sub-type' not in self.attr:
            return self.nested_struct_type

        elem_type = self.attr['sub-type']
        if elem_type == 'nest':
            return self.nested_struct_type
        if elem_type in scalars:
            # user space code uses the double-underscore scalar names
            prefix = '__' if ri.ku_space == 'user' else ''
            return prefix + elem_type
        if elem_type == 'binary' and 'exact-len' in self.checks:
            return None  # use arg_member()
        raise Exception(f"Sub-type {elem_type} not supported yet")

    def arg_member(self, ri):
        """Return the C argument declarations for this attribute."""
        if not (self.sub_type == 'binary' and 'exact-len' in self.checks):
            return super().arg_member(ri)

        exact_len = self.checks["exact-len"]
        return [f'unsigned char (*{self.c_name})[{exact_len}]',
                f'unsigned int n_{self.c_name}']

    def _attr_policy(self, policy):
        if self.attr['sub-type'] != 'nest':
            return super()._attr_policy(policy)
        return f'NLA_POLICY_NESTED_ARRAY({self.nested_render_name}_nl_policy)'

    def _attr_typol(self):
        """Return the YNL policy-table initializer fragment for this attribute."""
        elem_type = self.attr['sub-type']
        if elem_type in scalars:
            return f'.type = YNL_PT_U{c_upper(self.sub_type[1:])}, '
        if elem_type == 'binary' and 'exact-len' in self.checks:
            exact_len = self.checks["exact-len"]
            return f'.type = YNL_PT_BINARY, .len = {exact_len}, '
        if elem_type == 'nest':
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
        raise Exception(f"Typol for IndexedArray sub-type {elem_type} not supported, yet")

    def _attr_get(self, ri, var):
        """Return (get_lines, init_lines, local_vars) for parsing the attribute.

        The emitted C remembers the outer attribute, then validates and counts
        each nested entry.
        """
        get_lines = [
            f'attr_{self.c_name} = attr;',
            'ynl_attr_for_each_nested(attr2, attr) {',
            '\tif (__ynl_attr_validate(yarg, attr2, type))',
            '\t\treturn YNL_PARSE_CB_ERROR;',
            f'\tn_{self.c_name}++;',
            '}',
        ]
        local_vars = ['const struct nlattr *attr2;']
        return get_lines, None, local_vars

    def attr_put(self, ri, var):
        """Emit the C code which writes the array into a nest, entry by entry."""
        writer = ri.cw
        writer.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')

        loop_head = f'for (i = 0; i < {var}->_count.{self.c_name}; i++)'
        element = f'{var}->{self.c_name}[i]'
        if self.sub_type in scalars:
            writer.block_start(line=loop_head)
            writer.p(f"ynl_attr_put_{self.sub_type}(nlh, i, {element});")
            writer.block_end()
        elif self.sub_type == 'binary' and 'exact-len' in self.checks:
            exact_len = self.checks['exact-len']
            writer.p(loop_head)
            writer.p(f"ynl_attr_put(nlh, i, {element}, {exact_len});")
        elif self.sub_type == 'nest':
            writer.p(loop_head)
            writer.p(f"{self.nested_render_name}_put(nlh, i, &{element});")
        else:
            bad_type = self.attr['sub-type']
            raise Exception(f"Put for IndexedArray sub-type {bad_type} not supported, yet")
        writer.p('ynl_attr_nest_end(nlh, array);')

    def _setter_lines(self, ri, member, presence):
        assign_member = f"{member} = {self.c_name};"
        assign_presence = f"{presence} = n_{self.c_name};"
        return [assign_member, assign_presence]

    def free_needs_iter(self):
        return self.sub_type == 'nest'

    def _free_lines(self, ri, var, ref):
        """Return the C lines freeing the array (and each nested entry first)."""
        release_array = f"free({var}->{ref}{self.c_name});"
        if self.sub_type != 'nest':
            return [release_array]
        return [
            f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)",
            f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);',
            release_array,
        ]
876
class TypeNestTypeValue(Type):
    """Attribute type used for 'nest-type-value' attributes."""

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """Return (get_lines, init_lines, local_vars) for parsing the attribute.

        For every name listed under 'type-value' the emitted C descends one
        attribute level and records that level's attribute type; the collected
        values are passed as extra arguments to the nested parse call.
        """
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        local_vars = []
        get_lines = []
        cursor = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            attr_decls = ", *attr_".join(levels)
            value_decls = ", ".join(levels)
            local_vars.append(f'const struct nlattr *attr_{attr_decls};')
            local_vars.append(f'__u32 {value_decls};')

            for level in levels:
                get_lines.append(f'attr_{level} = ynl_attr_data({cursor});')
                get_lines.append(f'{level} = ynl_attr_type(attr_{level});')
                cursor = 'attr_' + level

            extra_args = f", {value_decls}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
905
906
class TypeSubMessage(TypeNest):
    """Nest-like attribute whose format is chosen by a selector attribute."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.selector = Selector(attr, attr_set)

    def _attr_typol(self):
        """Return the YNL policy-table initializer fragment for this attribute."""
        pieces = [
            f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, ',
            '.is_submsg = 1, ',
        ]
        # Reverse-parsing of the policy (ynl_err_walk() in ynl.c) does not
        # support external selectors. No family uses sub-messages with external
        # selector for requests so this is fine for now.
        if not self.selector.is_external():
            selector_value = self.attr_set[self["selector"]].value
            pieces.append(f'.selector_type = {selector_value} ')
        return ''.join(pieces)

    def _attr_get(self, ri, var):
        """Return (get_lines, init_lines, local_vars) for parsing the attribute.

        An external selector is read from a `_sel_`-prefixed variable, an
        internal one from the member of *var* named after the selector.
        """
        sel = c_lower(self['selector'])
        sel_var = f"_sel_{sel}" if self.selector.is_external() else f"{var}->{sel}"

        failure = 'return ynl_submsg_failed(yarg, "%s", "%s");' % \
            (self.name, self['selector'])
        get_lines = [
            f'if (!{sel_var})',
            failure,
            f"if ({self.nested_render_name}_parse(&parg, {sel_var}, attr))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        return get_lines, init_lines, None
937
938
class Selector:
    """Selector of a sub-message attribute.

    The selector is "external" when the attribute set which holds the
    sub-message does not itself contain an attribute with the selector's name.
    """

    def __init__(self, msg_attr, attr_set):
        self.name = msg_attr["selector"]

        self._external = self.name not in attr_set
        if self._external:
            # The selector will need to get passed down thru the structs
            self.attr = None
        else:
            self.attr = attr_set[self.name]
            self.attr.is_selector = True

    def set_attr(self, attr):
        """Record the attribute acting as the selector."""
        self.attr = attr

    def is_external(self):
        """Return True if the selector was not found in the attribute set."""
        return self._external
957
958
class Struct:
    """One C struct built from the attributes of an attribute set.

    When *type_list* is given only the listed attributes become members;
    when it is None every attribute of the set is included and the struct is
    flagged as nested.
    """

    def __init__(self, family, space_name, type_list=None, fixed_header=None,
                 inherited=None, submsg=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []
        self.fixed_header = ('struct ' + c_lower(fixed_header)) if fixed_header else None
        self.submsg = submsg

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            base_name = family.ident_name
        else:
            base_name = family.ident_name + '-' + space_name
        self.render_name = c_lower(base_name)
        self.struct_name = 'struct ' + self.render_name
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        self.in_multi_val = False  # used by a MultiAttr and by legacy arrays

        member_names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in member_names]

        # Index members by name and remember the one with the highest value
        self.attrs = dict()
        self.attr_max_val = None
        highest = 0
        for name, attr in self.attr_list:
            if attr.value >= highest:
                highest = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        for name in self.attrs:
            yield name

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the list of (name, attribute) member pairs."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Store the sorted C names of the inherited members.

        Raises if *new_inherited* differs from what was given at construction.
        """
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]

    def external_selectors(self):
        """Return the selectors of sub-message members which are external."""
        return [attr.selector for _, attr in self.attr_list
                if isinstance(attr, TypeSubMessage) and attr.selector.is_external()]

    def free_needs_iter(self):
        """Return True if any member reports that freeing it needs an iterator."""
        return any(attr.free_needs_iter() for _, attr in self.attr_list)
1033
1034
class EnumEntry(SpecEnumEntry):
    """Enum entry carrying the extra state used when rendering C."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # The value "changes" when it does not directly follow the previous
        # entry (or is not 0 for the first entry), and always for flags.
        expected = prev.value + 1 if prev else 0
        self.value_change = (self.value != expected) or \
            self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        del self.c_name

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
1053
1054
class EnumSet(SpecEnumSet):
    """Enum definition carrying the names used when rendering C."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        # 'enum-name' may override the C enum name; an empty value means
        # the enum has no name at all.
        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None

        self.user_type = self.enum_name if self.enum_name else 'int'

        default_pfx = f"{family.ident_name}-{yaml['name']}-"
        self.value_pfx = yaml.get('name-prefix', default_pfx)
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (low, high) if the entry values are contiguous, else (None, None)."""
        values = [entry.value for entry in self.entries.values()]
        low = min(values)
        high = max(values)

        if high - low + 1 == len(self.entries):
            return low, high
        return None, None
1090
1091
class AttrSet(SpecAttrSet):
    """Attribute set carrying C naming and building the typed attributes."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets reuse the naming of the set they are a subset of
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        del self.c_name

    def resolve(self):
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        if name == self.family.c_name:
            name = ''
        self.c_name = name

    def new_attr(self, elem, value):
        """Create the Type object matching the attribute definition *elem*.

        Raises for attribute types (or indexed-array sub-types) which have
        no matching class. Attributes marked 'multi-attr' get wrapped in
        TypeMultiAttr.
        """
        kind = elem['type']
        # Types which map to a class without looking at other properties
        simple_types = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'bitfield32': TypeBitfield32,
            'nest': TypeNest,
            'nest-type-value': TypeNestTypeValue,
            'sub-message': TypeSubMessage,
        }

        if kind in scalars:
            type_cls = TypeScalar
        elif kind in simple_types:
            type_cls = simple_types[kind]
        elif kind == 'binary':
            if 'struct' in elem:
                type_cls = TypeBinaryStruct
            elif elem.get('sub-type') in scalars:
                type_cls = TypeBinaryScalarArray
            else:
                type_cls = TypeBinary
        elif kind == 'indexed-array' and 'sub-type' in elem:
            sub_type = elem["sub-type"]
            if sub_type not in ['binary', 'nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {sub_type}')
            type_cls = TypeIndexedArray
        else:
            raise Exception(f"No typed class for type {kind}")

        t = type_cls(self.family, self, elem, value)

        if 'multi-attr' in elem and elem['multi-attr']:
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
1160
1161
class Operation(SpecOperation):
    """Operation carrying the names and flags used when rendering C."""

    def __init__(self, family, yaml, req_value, rsp_value):
        # Fill in missing operation properties (for fixed hdr-only msgs)
        for mode in ('do', 'dump', 'event'):
            if mode not in yaml:
                continue
            for direction in ('request', 'reply'):
                if direction in yaml[mode]:
                    yaml[mode][direction].setdefault('attributes', [])

        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        def has_request(mode):
            return mode in yaml and 'request' in yaml[mode]

        # True when both 'do' and 'dump' define a request
        self.dual_policy = has_request('do') and has_request('dump')

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        del self.enum_name

    def resolve(self):
        self.resolve_up(super())

        if self.is_async:
            prefix = self.family.async_op_prefix
        else:
            prefix = self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        self.has_ntf = True
1195
1196
class SubMessage(SpecSubMessage):
    """Sub-message definition carrying the name used when rendering C."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        full_name = family.ident_name + '-' + yaml['name']
        self.render_name = c_lower(full_name)

    def resolve(self):
        self.resolve_up(super())
1205
1206
class Family(SpecFamily):
    """Spec family extended with the bookkeeping needed to generate C code.

    resolve() fills in the C naming, then works out which attribute sets are
    used as roots of operations and which only as nests, and how request /
    reply usage, recursion and selectors propagate between them.
    """

    def __init__(self, file_name, exclude_ops, fn_prefix):
        # Added by resolve:
        self.c_name = None
        del self.c_name
        self.op_prefix = None
        del self.op_prefix
        self.async_op_prefix = None
        del self.async_op_prefix
        self.mcgrps = None
        del self.mcgrps
        self.consts = None
        del self.consts
        self.hooks = None
        del self.hooks

        super().__init__(file_name, exclude_ops=exclude_ops)

        family_name = self.yaml["name"]
        self.fam_key = c_upper(self.yaml.get('c-family-name', family_name + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', family_name + '_FAMILY_VERSION'))

        self.yaml.setdefault('definitions', [])

        self.uapi_header = self.yaml.get('uapi-header', f"linux/{self.ident_name}.h")
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            # strip the "linux/" prefix and the ".h" suffix
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

        self.fn_prefix = fn_prefix if fn_prefix else f'{self.ident_name}-nl'

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)

        operations = self.yaml['operations']
        if 'name-prefix' in operations:
            self.op_prefix = c_upper(operations['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in operations:
            self.async_op_prefix = c_upper(operations['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] -> 'set' for membership tests, 'list' for order
        self.hooks = {
            when: {op_mode: {'set': set(), 'list': []} for op_mode in ['do', 'dump']}
            for when in ['pre', 'post']
        }

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_selector_passing()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def new_sub_message(self, elem):
        return SubMessage(self, elem)

    def is_classic(self):
        return self.proto == 'netlink-raw'

    def _mark_notify(self):
        """Flag every operation which some message names in its 'notify'."""
        for msg in self.msgs.values():
            if 'notify' in msg:
                self.ops[msg['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' not in op:
                continue
            reply = {'attributes': op['event']['attributes']}
            op['do'] = {'reply': reply}

    def _load_root_sets(self):
        """Collect, per attribute set, the attrs used in requests and replies."""
        for op in self.msgs.values():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                if 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            set_name = op['attribute-set']
            if set_name in self.root_sets:
                self.root_sets[set_name]['request'].update(req_attrs)
                self.root_sets[set_name]['reply'].update(rsp_attrs)
            else:
                self.root_sets[set_name] = {'request': req_attrs, 'reply': rsp_attrs}

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so structs follow the ones they nest."""
        # Try to reorder according to dependencies
        pending = list(self.pure_nested_structs.keys())
        placed = set()
        max_rounds = len(pending) ** 2  # it's basically bubble sort
        for _ in range(max_rounds):
            if not pending:
                break
            name = pending.pop(0)
            deferred = False
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    continue

                # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                if self.pure_nested_structs[nested].recursive:
                    continue
                if nested not in placed:
                    # Dicts are sorted, this will make struct last
                    self.pure_nested_structs[name] = self.pure_nested_structs.pop(name)
                    deferred = True
                    break
            if deferred:
                pending.append(name)
            else:
                placed.add(name)

    def _load_nested_set_nest(self, spec):
        """Register the Struct for a nested attribute set and return its name."""
        # NOTE: this set is handed to a newly created Struct before it is
        # filled in below, so the Struct observes the additions.
        inherit = set()
        nested = spec['nested-attributes']
        if nested in self.root_sets:
            raise Exception(f'Using attr set as root and nested not supported - {nested}')
        if nested not in self.pure_nested_structs:
            self.pure_nested_structs[nested] = \
                Struct(self, nested, inherited=inherit,
                       fixed_header=spec.get('fixed-header'))

        if 'type-value' in spec:
            if nested in self.root_sets:
                raise Exception("Inheriting members to a space used as root not supported")
            inherit.update(set(spec['type-value']))
        elif spec['type'] == 'indexed-array':
            inherit.add('idx')
        self.pure_nested_structs[nested].set_inherited(inherit)

        return nested

    def _load_nested_set_submsg(self, spec):
        """Build a stand-in attribute set and Struct for a sub-message.

        Returns the sub-message name, which doubles as the attribute set name.
        """
        # Fake the struct type for the sub-message itself
        # its not a attr_set but codegen wants attr_sets.
        submsg = self.sub_msgs[spec["sub-message"]]
        nested = submsg.name

        attrs = []
        for name, fmt in submsg.formats.items():
            attr = {
                "name": name,
                "parent-sub-message": spec,
            }
            if 'attribute-set' in fmt:
                attr["type"] = "nest"
                attr["nested-attributes"] = fmt['attribute-set']
                if 'fixed-header' in fmt:
                    attr["fixed-header"] = fmt["fixed-header"]
            elif 'fixed-header' in fmt:
                attr["type"] = "binary"
                attr["struct"] = fmt["fixed-header"]
            else:
                attr["type"] = "flag"
            attrs.append(attr)

        # NOTE(review): AttrSet.__init__ looks up 'name-prefix', not
        # 'name-pfx' - confirm whether this key is consumed anywhere.
        fake_set = {
            "name": nested,
            "name-pfx": self.name + '-' + spec.name + '-',
            "attributes": attrs
        }
        self.attr_sets[nested] = AttrSet(self, fake_set)

        if nested not in self.pure_nested_structs:
            self.pure_nested_structs[nested] = Struct(self, nested, submsg=submsg)

        return nested

    def _load_nested_sets(self):
        """Discover all nested sets and propagate their usage flags."""
        # Walk down from the root sets, creating Structs for everything nested
        queue = list(self.root_sets.keys())
        seen = set(self.root_sets.keys())

        while queue:
            a_set = queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' in spec:
                    nested = self._load_nested_set_nest(spec)
                elif 'sub-message' in spec:
                    nested = self._load_nested_set_submsg(spec)
                else:
                    continue

                if nested not in seen:
                    queue.append(nested)
                    seen.add(nested)

        # Seed request / reply / multi-val from the way root sets use the nests
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    nested = None

                if not nested:
                    continue

                if attr in rs_members['request']:
                    self.pure_nested_structs[nested].request = True
                if attr in rs_members['reply']:
                    self.pure_nested_structs[nested].reply = True

                if spec.is_multi_val():
                    child = self.pure_nested_structs.get(nested)
                    child.in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if attr_set in struct.child_nests:
                    struct.recursive = True

                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                struct.child_nests.add(child_name)
                child = self.pure_nested_structs.get(child_name)
                if not child:
                    continue
                if not child.recursive:
                    struct.child_nests.update(child.child_nests)
                child.request |= struct.request
                child.reply |= struct.reply
                if spec.is_multi_val():
                    child.in_multi_val = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Mark attributes as used in requests and / or replies."""
        for struct in self.pure_nested_structs.values():
            if struct.request:
                for _, member in struct.member_list():
                    member.set_request()
            if struct.reply:
                for _, member in struct.member_list():
                    member.set_reply()

        for root_set, rs_members in self.root_sets.items():
            requested = rs_members['request']
            replied = rs_members['reply']
            for attr, spec in self.attr_sets[root_set].items():
                if attr in requested:
                    spec.set_request()
                if attr in replied:
                    spec.set_reply()

    def _load_selector_passing(self):
        """Point external selectors at the attribute of the enclosing set.

        Raises if the enclosing set does not hold the selector attribute.
        """
        # Pure nested sets in reverse order, then the root sets; root sets
        # have no Struct of their own, but they must be terminal.
        set_names = [name for name, _ in reversed(self.pure_nested_structs.items())]
        set_names += list(self.root_sets)

        for attr_set in set_names:
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                child = self.pure_nested_structs.get(child_name)
                for selector in child.external_selectors():
                    if selector.name not in self.attr_sets[attr_set]:
                        raise Exception("Passing selector thru more than one layer not supported")
                    selector.set_attr(self.attr_sets[attr_set][selector.name])

    def _load_global_policy(self):
        """Collect the request attributes of all ops into one global policy.

        Raises if the operations do not all use the same attribute set.
        """
        global_set = set()
        attr_set_name = None
        for op in self.ops.values():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                req = op[op_mode].get('request')
                if req:
                    global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        # Keep the order in which the attribute set lists its attributes
        self.global_policy.extend(attr for attr in self.attr_sets[attr_set_name]
                                  if attr in global_set)

    def _load_hooks(self):
        """Record each distinct pre / post hook name, in order of appearance."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                mode_spec = op[op_mode]
                for when in ['pre', 'post']:
                    if when not in mode_spec:
                        continue
                    name = mode_spec[when]
                    known = self.hooks[when][op_mode]
                    if name not in known['set']:
                        known['set'].add(name)
                        known['list'].append(name)
1569
1570
class RenderInfo:
    """State describing what is being rendered: an operation or an attr set."""

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        fixed_hdr = op.fixed_header if op else None
        self.fixed_hdr_len = 'ys->family->hdr_len'
        if op and op.fixed_header and op.fixed_header != family.fixed_header:
            if not family.is_classic():
                raise Exception("Per-op fixed header not supported, yet")
            self.fixed_hdr_len = f"sizeof(struct {c_lower(fixed_hdr)})"

        # 'do' and 'dump' response parsing is identical
        self.type_consistent = True
        self.type_oneside = False
        if op_mode != 'do' and 'dump' in op:
            if 'do' not in op:
                self.type_oneside = True
            else:
                do_reply = 'reply' in op['do']
                dump_reply = 'reply' in op["dump"]
                if do_reply != dump_reply:
                    self.type_consistent = False
                elif do_reply and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False

        self.attr_set = attr_set if attr_set else op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            if attr_set in family.consts:
                self.type_name_conflict = True

        self.cw = cw

        self.struct = dict()
        # Notifications use the 'do' definition, or 'dump' if there is no 'do'
        if op_mode == 'notify':
            op_mode = 'do' if 'do' in op else 'dump'
        if op:
            for op_dir in ['request', 'reply']:
                type_list = []
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set,
                                             fixed_header=fixed_hdr,
                                             type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set,
                                          fixed_header=fixed_hdr,
                                          type_list=op['event']['attributes'])

    def type_empty(self, key):
        """Return True if struct *key* has no members and the request struct
        has no fixed header."""
        if len(self.struct[key].attr_list) != 0:
            return False
        return self.struct['request'].fixed_header is None

    def needs_nlflags(self, direction):
        """Return whether a 'do' request of a classic family is being rendered."""
        if self.op_mode != 'do' or direction != 'request':
            return False
        return self.family.is_classic()
1638
1639
class CodeWriter:
    """
    Line-oriented writer for the generated C code.

    Tracks the indentation depth, delays printing of closing brackets so
    that a following "else" can be joined onto the same line, and indents
    the single statement which follows a bracket-less if / while / for /
    else. Output goes to stdout or, when @out_file is given, to a temporary
    file which close_out_file() copies over to @out_file.
    """
    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False
        self._block_end = False
        self._silent_block = False
        self._ind = 0
        self._ifdef_block = None
        # Always defined, so close_out_file() never trips over a missing attr
        self._out_file = out_file
        if out_file is None:
            self._out = sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """
        Copy the temporary output to the destination file and release it.

        With overwrite=False an existing destination with identical contents
        is left untouched. The temporary file is closed on every path
        (including errors) and the writer falls back to stdout, so calling
        this more than once is harmless.
        """
        if self._out is sys.stdout:
            return
        tmp = self._out
        try:
            tmp.flush()
            # Avoid modifying the file if contents didn't change
            unchanged = (not self._overwrite and
                         os.path.isfile(self._out_file) and
                         filecmp.cmp(tmp.name, self._out_file, shallow=False))
            if not unchanged:
                tmp.seek(0)
                with open(self._out_file, 'w+') as out_file:
                    shutil.copyfileobj(tmp, out_file)
        finally:
            tmp.close()
            self._out = sys.stdout

    @classmethod
    def _is_cond(cls, line):
        """
        True if @line starts with "if", "while" or "for".
        """
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """
        Print one line at the current indentation.

        Flushes a pending closing bracket and a pending empty line first.
        Lines ending with ':' (labels) are printed one level out, lines
        starting with '#' (preprocessor) at column 0, and @add_ind is added
        on top of the computed depth. An empty @line produces an empty
        output line.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        # endswith() / startswith() rather than indexing, @line may be empty
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        # A condition without a bracket indents only the very next line
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        self._silent_block |= line.strip() == 'else'
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        # Don't emit trailing whitespace for empty lines
        prefix = '\t' * ind if line else ''
        self._out.write(prefix + line + '\n')

    def nl(self):
        """
        Request an empty line before the next printed line.
        """
        self._nl = True

    def block_start(self, line=''):
        """
        Print @line followed by an opening bracket, increase indentation.
        """
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """
        Close the current block, @line is appended after the bracket.

        Without @line the bracket is not printed until the next p() call.
        Any pending empty line is dropped.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """
        Print @doc as ' *' prefixed comment lines, wrapped before column 79.

        Continuation lines get two extra spaces when @indent is set.
        """
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """
        Print a C function prototype.

        @qual_ret holds the qualifiers and return type, @args the list of
        argument declarations ('void' if empty), @doc an optional comment
        printed above and @suffix text to follow the closing parenthesis.
        Prototypes which would take 80 or more characters are split over
        multiple lines, with arguments aligned after the opening parenthesis.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        is_ptr = qual_ret.endswith('*')

        oneline = qual_ret
        if not is_ptr:
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            # Long return type goes on a line of its own
            self.p(v)
            v = ''
        elif not is_ptr:
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        # Tabs take one char in the string but up to 8 columns on screen
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """
        Print local variable declarations, longest first, then an empty line.

        @local_vars is a single declaration or an iterable of declarations,
        it is not modified. Declarations of equal length keep their order.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """
        Print a complete function: prototype, local variables and @body lines.
        """
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """
        Print #define lines for (name, value) pairs with tab-aligned values.

        int values are printed as is, str values in double quotes, values
        of any other type are left out.
        """
        longest = max((len(define[0]) for define in defines), default=0)
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            # Exact type match on purpose, bool must not be rendered as int
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """
        Print designated initializers for (member, value) pairs, tab-aligned.

        Nothing is printed for an empty @members.
        """
        if not members:
            return

        longest = max(len(x[0]) for x in members)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """
        Switch the currently open #ifdef CONFIG_xyz section to @config.

        Closes the previously opened section if it differs, and opens a new
        one unless @config is empty.
        """
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1836
1837
# Attribute (sub-)types which the generated parsers read with
# ynl_attr_get_<type>()
scalars = {'u8', 'u16', 'u32', 'u64', 's8', 's16', 's32', 's64', 'uint', 'sint'}

# Suffix of the generated type name for each message direction; without
# the leading underscore it also serves as the C argument name
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': ''
}

# Suffix op_prefix() appends to the reply type name for each op mode
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# C keywords and built-in type names
# NOTE(review): presumably consulted to avoid emitting identifiers which
# clash with them - the users are outside of this section, confirm
_C_KW = {
    'auto',
    'bool',
    'break',
    'case',
    'char',
    'const',
    'continue',
    'default',
    'do',
    'double',
    'else',
    'enum',
    'extern',
    'float',
    'for',
    'goto',
    'if',
    'inline',
    'int',
    'long',
    'register',
    'return',
    'short',
    'signed',
    'sizeof',
    'static',
    'struct',
    'switch',
    'typedef',
    'union',
    'unsigned',
    'void',
    'volatile',
    'while'
}
1889
1890
def rdir(direction):
    """
    Return the opposite message direction, anything which is neither
    'request' nor 'reply' is returned unchanged.
    """
    if direction == 'request':
        return 'reply'
    return 'request' if direction == 'reply' else direction
1897
1898
def op_prefix(ri, direction, deref=False):
    """
    Build the C identifier prefix "<family>_<type><suffix>" for @ri.

    The suffix depends on the op mode, the @direction and, for replies of
    modes other than 'do', on whether the element type (@deref) or the
    wrapper type is wanted.
    """
    mode = ri.op_mode

    if not mode:
        tail = ''
    elif mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req' if ri.type_oneside else '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[mode]

    return f"{ri.family.c_name}_{ri.type_name}{tail}"
1922
1923
def type_name(ri, direction, deref=False):
    """
    Return the C struct type ("struct <prefix>") for @ri and @direction.
    """
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct {}'.format(prefix)
1926
1927
def print_prototype(ri, direction, terminate=True, doc=None):
    """
    Print the prototype of the function which executes the operation.

    The function takes the socket, plus the request struct if the op mode
    has a request. It returns a pointer to the reply type if the op mode
    has a reply, int otherwise. @terminate adds the trailing semicolon.
    """
    mode_spec = ri.op[ri.op_mode]

    fname = ri.op.render_name
    if ri.op_mode == 'dump':
        fname += '_dump'

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        req_type = type_name(ri, direction)
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{req_type} *{arg_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1944
1945
def print_req_prototype(ri):
    """
    Print the terminated request prototype, with the op's doc as a comment.
    """
    op_doc = ri.op['doc']
    print_prototype(ri, 'request', doc=op_doc)
1948
1949
def print_dump_prototype(ri):
    """
    Print the terminated prototype for the 'request' direction, no comment.
    """
    print_prototype(ri, direction='request')
1952
1953
def put_typol_submsg(cw, struct):
    """
    Print the policy table and the policy nest for a sub-message @struct.

    Members are indexed by their position, 'nest' members also point at
    their nested policy.
    """
    base = struct.render_name

    cw.block_start(line=f'const struct ynl_policy_attr {base}_policy[] =')
    count = 0
    for idx, (name, arg) in enumerate(struct.member_list()):
        if arg.type == 'nest':
            nest = f" .nest = &{arg.nested_render_name}_nest,"
        else:
            nest = ""
        cw.p('[%d] = { .type = YNL_PT_SUBMSG, .name = "%s",%s },' %
             (idx, name, nest))
        count = idx + 1
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {base}_nest =')
    cw.p(f'.max_attr = {count - 1},')
    cw.p(f'.table = {base}_policy,')
    cw.block_end(line=';')
    cw.nl()
1974
1975
def put_typol_fwd(cw, struct):
    """
    Print the extern declaration of the policy nest of @struct.
    """
    decl = 'extern const struct ynl_policy_nest {}_nest;'
    cw.p(decl.format(struct.render_name))
1978
1979
def put_typol(cw, struct):
    """
    Print the policy table and the policy nest for @struct.

    Sub-message structs are handed over to put_typol_submsg(), for the
    rest each member prints its own policy entry.
    """
    if struct.submsg:
        put_typol_submsg(cw, struct)
        return

    base = struct.render_name
    type_max = struct.attr_set.max_name

    cw.block_start(line=f'const struct ynl_policy_attr {base}_policy[{type_max} + 1] =')
    for member in struct.member_list():
        member[1].attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {base}_nest =')
    cw.p(f'.max_attr = {type_max},')
    cw.p(f'.table = {base}_policy,')
    cw.block_end(line=';')
    cw.nl()
1999
2000
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """
    Print the <render_name>_str() function which looks up @map_name.

    The argument is an int unless @enum is given, in which case the enum's
    user type is used. For 'flags' enums the value is first converted to a
    bit index with ffs(). Out of range values make the function return NULL.
    """
    if enum:
        args = [enum.user_type + ' ' + arg_name]
    else:
        args = [f'int {arg_name}']

    cw.write_func_prot('const char *', f'{render_name}_str', args)
    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    body = (
        f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))',
        'return NULL;',
        f'return {map_name}[{arg_name}];',
    )
    for line in body:
        cw.p(line)
    cw.block_end()
    cw.nl()
2014
2015
def put_op_name_fwd(family, cw):
    """
    Print the prototype of the family's <family>_op_str() function.
    """
    func_name = f'{family.c_name}_op_str'
    cw.write_func_prot('const char *', func_name, ['int op'], suffix=';')
2018
2019
def put_op_name(family, cw):
    """
    Print the op value -> name string map and its <family>_op_str() helper.

    Only messages with a response value get an entry.
    """
    map_name = f'{family.c_name}_op_strmap'

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue

        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue

        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
2039
2040
def put_enum_to_str_fwd(family, cw, enum):
    """
    Print the prototype of the <enum>_str() function, @family is unused.
    """
    value_arg = enum.user_type + ' value'
    cw.write_func_prot('const char *', f'{enum.render_name}_str',
                       [value_arg], suffix=';')
2044
2045
def put_enum_to_str(family, cw, enum):
    """
    Print the value -> name string map of @enum and its <enum>_str() helper,
    @family is unused.
    """
    map_name = f'{enum.render_name}_strmap'
    entries = [f'[{entry.value}] = "{entry.name}",'
               for entry in enum.entries.values()]

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for line in entries:
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
2055
2056
def put_local_vars(struct):
    """
    Return the local variable declarations the members of @struct need
    in a put function: an array pointer if any member is an indexed-array,
    an index if any member has presence type 'count'.
    """
    traits = [(arg.type == 'indexed-array', arg.presence_type() == 'count')
              for _, arg in struct.member_list()]

    local_vars = []
    if any(is_array for is_array, _ in traits):
        local_vars.append('struct nlattr *array;')
    if any(is_count for _, is_count in traits):
        local_vars.append('unsigned int i;')
    return local_vars
2069
2070
def put_req_nested_prototype(ri, struct, suffix=';'):
    """
    Print the prototype of the <struct>_put() function of a nested struct.
    """
    obj_arg = f'{struct.ptr_name}obj'
    ri.cw.write_func_prot('int', f'{struct.render_name}_put',
                          ['struct nlmsghdr *nlh',
                           'unsigned int attr_type',
                           obj_arg],
                          suffix=suffix)
2078
2079
def put_req_nested(ri, struct):
    """
    Print the <struct>_put() function which serializes a nested struct.

    Structs which are not sub-messages get wrapped in a nest attribute,
    a fixed header (if any) is copied in before the attributes.
    """
    cw = ri.cw
    wrap_in_nest = struct.submsg is None

    local_vars = []
    init_lines = []
    if wrap_in_nest:
        local_vars.append('struct nlattr *nest;')
        init_lines.append("nest = ynl_attr_nest_start(nlh, attr_type);")
    if struct.fixed_header:
        struct_sz = f'sizeof({struct.fixed_header})'
        local_vars.append('void *hdr;')
        init_lines += [f"hdr = ynl_nlmsg_put_extra_header(nlh, {struct_sz});",
                       f"memcpy(hdr, &obj->_hdr, {struct_sz});"]
    local_vars.extend(put_local_vars(struct))

    put_req_nested_prototype(ri, struct, suffix='')
    cw.block_start()
    cw.write_func_lvar(local_vars)

    for line in init_lines:
        cw.p(line)

    for _, member in struct.member_list():
        member.attr_put(ri, "obj")

    if wrap_in_nest:
        cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
2112
2113
def _multi_parse(ri, struct, init_lines, local_vars):
    """
    Print the body of a function parsing the attributes of @struct.

    The caller is expected to have printed the prototype. Output consists
    of the local variables (@local_vars plus whatever the members need),
    @init_lines, a loop over the attributes in which each member parses
    itself, and one extra pass per indexed-array / multi-attr member which
    allocates the array (based on the count from the first pass) and then
    fills it in.

    @local_vars and @init_lines are extended in place.
    """
    if struct.fixed_header:
        local_vars += ['void *hdr;']
    if struct.nested:
        if struct.fixed_header:
            iter_line = f"ynl_attr_for_each_nested_off(attr, nested, sizeof({struct.fixed_header}))"
        else:
            iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
        if ri.op.fixed_header != ri.family.fixed_header:
            if ri.family.is_classic():
                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({struct.fixed_header}))"
            else:
                raise Exception("Per-op fixed header not supported, yet")

    indexed_arrays = set()
    multi_attrs = set()
    needs_parg = False
    var_set = set()
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] in {'binary', 'nest'} or aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name} = NULL;')
                indexed_arrays.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
        needs_parg |= 'sub-message' in aspec

        try:
            _, _, l_vars = aspec._attr_get(ri, '')
            var_set |= set(l_vars) if l_vars else set()
        except Exception:
            pass  # _attr_get() not implemented by simple types, ignore
    # Sort, set iteration order changes from run to run and the length-only
    # sort in write_func_lvar() doesn't order declarations of equal length
    local_vars += sorted(var_set)
    if indexed_arrays or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = indexed_arrays | multi_attrs

    for arg in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[arg].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if struct.fixed_header:
        if struct.nested:
            ri.cw.p('hdr = ynl_attr_data(nested);')
        elif ri.family.is_classic():
            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
        else:
            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({struct.fixed_header}));")
    # Arrays get allocated below, refuse to parse into a used struct
    for arg in sorted(all_multi):
        aspec = struct[arg]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for indexed arrays: allocate, then walk the saved nest
    for arg in sorted(indexed_arrays):
        aspec = struct[arg]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
            # Length is validated by typol
            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attrs: allocate, then re-walk all the attributes
    for arg in sorted(multi_attrs):
        aspec = struct[arg]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        elif aspec.type == 'binary' and 'struct' in aspec:
            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
            ri.cw.nl()
            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
        elif aspec.type == 'string':
            ri.cw.p('unsigned int len;')
            ri.cw.nl()
            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
        else:
            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
2273
2274
def parse_rsp_submsg(ri, struct):
    """
    Print the parse function of a sub-message @struct.

    The generated function compares the selector string with the name of
    each member and parses the payload as the member which matched.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    var = 'dst'
    local_vars = {'const struct nlattr *attr = nested;',
                  f'{struct.ptr_name}{var} = yarg->data;',
                  'struct ynl_parse_arg parg;'}

    for _, arg in struct.member_list():
        _, _, l_vars = arg._attr_get(ri, var)
        local_vars |= set(l_vars) if l_vars else set()

    ri.cw.block_start()
    # Sort, set iteration order changes from run to run and the length-only
    # sort in write_func_lvar() doesn't order declarations of equal length
    ri.cw.write_func_lvar(sorted(local_vars))
    ri.cw.p('parg.ys = yarg->ys;')
    ri.cw.nl()

    first = True
    for name, arg in struct.member_list():
        kw = 'if' if first else 'else if'
        first = False

        ri.cw.block_start(line=f'{kw} (!strcmp(sel, "{name}"))')
        get_lines, init_lines, _ = arg._attr_get(ri, var)
        for line in init_lines or []:
            ri.cw.p(line)
        for line in get_lines:
            ri.cw.p(line)
        if arg.presence_type() == 'present':
            ri.cw.p(f"{var}->_present.{arg.c_name} = 1;")
        ri.cw.block_end()
    ri.cw.p('return 0;')
    ri.cw.block_end()
    ri.cw.nl()
2309
2310
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """
    Print the prototype of the <struct>_parse() function of a nested struct.

    Arguments are the parse arg, the selector (sub-messages only), the
    nest, one string per external selector and one __u32 per inherited
    value.
    """
    selectors = ['const char *_sel_' + sel.name
                 for sel in struct.external_selectors()]

    func_args = ['struct ynl_parse_arg *yarg']
    if struct.submsg:
        func_args.append('const char *sel')
    func_args.append('const struct nlattr *nested')
    func_args.extend(selectors)
    func_args.extend('__u32 ' + arg for arg in struct.inherited)

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args,
                          suffix=suffix)
2323
2324
def parse_rsp_nested(ri, struct):
    """
    Print the parse function of a nested @struct.

    Sub-messages are handled by parse_rsp_submsg(), structs without
    members get a function which only returns 0.
    """
    if struct.submsg:
        return parse_rsp_submsg(ri, struct)

    parse_rsp_nested_prototype(ri, struct, suffix='')

    init_lines = []
    local_vars = ['const struct nlattr *attr;']
    local_vars.append(f'{struct.ptr_name}dst = yarg->data;')

    if not struct.member_list():
        # Empty nest
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, struct, init_lines, local_vars)
2343
2344
def parse_rsp_msg(ri, deref=False):
    """
    Print the function parsing the reply message of the operation.

    Nothing is printed if the op mode has no reply and is not an event.
    Replies without members get a function which only returns
    YNL_PARSE_CB_OK.
    """
    if ri.op_mode != 'event' and 'reply' not in ri.op[ri.op_mode]:
        return

    reply = ri.struct["reply"]

    local_vars = [f'{type_name(ri, "reply", deref=deref)} *dst;',
                  'const struct nlattr *attr;']
    init_lines = ['dst = yarg->data;']

    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse',
                          ['const struct nlmsghdr *nlh',
                           'struct ynl_parse_arg *yarg'])

    if not reply.member_list():
        # Empty reply
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply, init_lines, local_vars)
2366
2367
def print_req(ri):
    """
    Print the function which executes a single ('do') request.

    The generated function builds the message from the request struct,
    executes it, and if the operation has a reply - returns the parsed
    response (freeing it and returning NULL on error).
    """
    direction = "request"
    cw = ri.cw
    has_reply = 'reply' in ri.op[ri.op_mode]
    request = ri.struct["request"]

    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']
    if has_reply:
        ret_ok, ret_err = 'rsp', 'NULL'
        local_vars.append(f'{type_name(ri, rdir(direction))} *rsp;')
    else:
        ret_ok, ret_err = '0', '-1'
    if request.fixed_header:
        local_vars.extend(['size_t hdr_len;', 'void *hdr;'])
    local_vars.extend(put_local_vars(request))

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(local_vars)

    if ri.family.is_classic():
        start = f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);"
    else:
        start = f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);"
    cw.p(start)

    cw.p(f"ys->req_policy = &{request.render_name}_nest;")
    cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
    if has_reply:
        cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()

    if request.fixed_header:
        for line in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            cw.p(line)
        cw.nl()

    for _, attr in request.member_list():
        attr.attr_put(ri, "req")
    cw.nl()

    if has_reply:
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yrs.yarg.data = rsp;')
        cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        cw.nl()
    cw.p("err = ynl_exec(ys, nlh, &yrs);")
    cw.p('if (err < 0)')
    cw.p('goto err_free;' if has_reply else 'return -1;')
    cw.nl()

    cw.p(f"return {ret_ok};")
    cw.nl()

    if has_reply:
        cw.p('err_free:')
        cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        cw.p(f"return {ret_err};")

    cw.block_end()
2438
2439
def print_dump(ri):
    """
    Print the function which executes a dump request.

    The generated function sets up the dump state, builds the message
    (from the request struct, if the dump takes a request), executes it
    and returns the list of responses (freeing it and returning NULL on
    error).
    """
    direction = "request"
    cw = ri.cw

    print_prototype(ri, direction, terminate=False)
    cw.block_start()

    request = ri.struct['request']
    has_request = 'request' in ri.op[ri.op_mode]

    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']
    if request.fixed_header:
        local_vars.extend(['size_t hdr_len;', 'void *hdr;'])
    if has_request:
        local_vars.extend(put_local_vars(request))

    cw.write_func_lvar(local_vars)

    cw.p('yds.yarg.ys = ys;')
    cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.p("yds.yarg.data = NULL;")
    cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    cw.p(f'yds.rsp_cmd = {rsp_cmd};')
    cw.nl()

    if ri.family.is_classic():
        start = f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});"
    else:
        start = f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);"
    cw.p(start)

    if request.fixed_header:
        for line in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            cw.p(line)
        cw.nl()

    if has_request:
        cw.p(f"ys->req_policy = &{request.render_name}_nest;")
        cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
        cw.nl()
        for _, attr in request.member_list():
            attr.attr_put(ri, "req")
    cw.nl()

    for line in ('err = ynl_exec_dump(ys, nlh, &yds);',
                 'if (err < 0)',
                 'goto free_list;'):
        cw.p(line)
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
2497
2498
def call_free(ri, direction, var):
    """
    Return the C statement calling the op's ``_free`` helper on ``var``.

    The helper name is the op prefix for ``direction`` followed by ``_free``.
    """
    free_fn = f"{op_prefix(ri, direction)}_free"
    return f"{free_fn}({var});"
2501
2502
def free_arg_name(direction):
    """
    Return the argument name used by generated ``_free`` functions.

    A falsy direction yields 'obj'; otherwise the name is the direction's
    entry in direction_to_suffix with its first character dropped.
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
2507
2508
def print_alloc_wrapper(ri, direction, struct=None):
    """
    Emit a static inline ``<prefix>_alloc()`` helper wrapping calloc().

    When ``struct`` is given and is used in a multi-value attribute, the
    helper takes an element count ``n``; otherwise it allocates one object.
    """
    func_base = op_prefix(ri, direction)
    # On a type name conflict the struct name carries a trailing underscore
    c_struct = func_base + '_' if ri.type_name_conflict else func_base

    if struct and struct.in_multi_val:
        params, count = ["unsigned int n"], "n"
    else:
        params, count = ["void"], "1"

    cw = ri.cw
    cw.write_func_prot(f'static inline struct {c_struct} *',
                       f"{func_base}_alloc", params)
    cw.block_start()
    cw.p(f'return calloc({count}, sizeof(struct {c_struct}));')
    cw.block_end()
2526
2527
def print_free_prototype(ri, direction, suffix=';'):
    """
    Emit the prototype of the ``<prefix>_free()`` function.

    ``suffix`` is passed through to the writer; callers in this file pass ''
    when a function body follows the prototype.
    """
    func_base = op_prefix(ri, direction)
    # On a type name conflict the struct name carries a trailing underscore
    c_struct = f"{func_base}_" if ri.type_name_conflict else func_base
    param = f"struct {c_struct} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{func_base}_free", [param], suffix=suffix)
2535
2536
def print_nlflags_set(ri, direction):
    """
    Emit the static inline ``<prefix>_set_nlflags()`` setter, which stores
    its ``nl_flags`` argument in the ``_nlmsg_flags`` member of the request.
    """
    prefix = op_prefix(ri, direction)
    params = [f"struct {prefix} *req", "__u16 nl_flags"]

    cw = ri.cw
    cw.write_func_prot('static inline void', f"{prefix}_set_nlflags", params)
    cw.block_start()
    cw.p('req->_nlmsg_flags = nl_flags;')
    cw.block_end()
    cw.nl()
2545
2546
def _print_type(ri, direction, struct):
    """
    Emit the C struct definition for ``struct``.

    Layout, in order: optional ``_nlmsg_flags``, optional fixed header
    (``_hdr``), the ``_present`` / ``_len`` / ``_count`` meta structs (each
    only if some member needs it), inherited ``__u32`` members and finally
    the attribute members themselves.
    """
    cw = ri.cw

    type_suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if ri.type_name_conflict and not direction:
        type_suffix += '_'
    if ri.op_mode == 'dump' and not ri.type_oneside:
        type_suffix += '_dump'

    cw.block_start(line=f"struct {ri.family.c_name}{type_suffix}")

    if ri.needs_nlflags(direction):
        cw.p('__u16 _nlmsg_flags;')
        cw.nl()
    if struct.fixed_header:
        cw.p(f"{struct.fixed_header} _hdr;")
        cw.nl()

    # One anonymous meta struct per kind, skipped when no member uses it
    for kind in ('present', 'len', 'count'):
        meta_lines = [text for text in
                      (member.presence_member(ri.ku_space, kind)
                       for _, member in struct.member_list())
                      if text]
        if not meta_lines:
            continue
        cw.block_start(line="struct")
        for text in meta_lines:
            cw.p(text)
        cw.block_end(line=f'_{kind};')
    cw.nl()

    for inherited in struct.inherited:
        cw.p(f"__u32 {inherited};")

    for _, member in struct.member_list():
        member.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
2585
2586
def print_type(ri, direction):
    """Emit the struct definition for the op's ``direction`` struct."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
2589
2590
def print_type_full(ri, struct):
    """
    Emit a directionless struct definition and, for request structs used in
    multi-value attributes, its alloc wrapper, free prototype and setters.
    """
    _print_type(ri, "", struct)

    if not (struct.request and struct.in_multi_val):
        return

    print_alloc_wrapper(ri, "", struct)
    ri.cw.nl()
    free_rsp_nested_prototype(ri)
    ri.cw.nl()

    # Name conflicts are too hard to deal with in the current code base,
    # they are very rare so setters are simply not printed in that case.
    if ri.ku_space == 'user' and not ri.type_name_conflict:
        for _, member in struct.member_list():
            member.setter(ri, ri.attr_set, "", var="obj")
    ri.cw.nl()
2606
2607
def print_type_helpers(ri, direction, deref=False):
    """
    Emit helper declarations for a type: the free prototype, the nlflags
    setter when needed, and (user space requests only) attribute setters.
    """
    print_free_prototype(ri, direction)
    ri.cw.nl()

    if ri.needs_nlflags(direction):
        print_nlflags_set(ri, direction)

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    members = ri.struct[direction].member_list() if wants_setters else []
    for _, member in members:
        member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
2619
2620
def print_req_type_helpers(ri):
    """Emit alloc wrapper and helpers for the request type, unless empty."""
    if not ri.type_empty("request"):
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
2626
2627
def print_rsp_type_helpers(ri):
    """Emit helpers for the reply type, if the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2632
2633
def print_parse_prototype(ri, direction, terminate=True):
    """
    Emit the prototype of the ``<op>_rsp_parse`` / ``<op>_req_parse``
    function; with ``terminate`` the prototype ends in a semicolon.
    """
    if direction == "reply":
        dir_suffix = "_rsp"
    else:
        dir_suffix = "_req"
    struct_base = f"{ri.op.render_name}{dir_suffix}"

    params = ['const struct nlattr **tb',
              f"struct {struct_base} *req"]
    ri.cw.write_func_prot('void', f"{struct_base}_parse", params,
                          suffix=';' if terminate else '')
2642
2643
def print_req_type(ri):
    """Emit the request struct definition, unless the request is empty."""
    if not ri.type_empty("request"):
        print_type(ri, "request")
2648
2649
def print_req_free(ri):
    """Emit the request free function, if the op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2654
2655
def print_rsp_type(ri):
    """
    Emit the reply struct definition.

    Printed for 'do' / 'dump' modes which declare a reply, and always for
    'event' mode; any other mode prints nothing.
    """
    mode = ri.op_mode
    has_reply = mode in ('do', 'dump') and 'reply' in ri.op[mode]
    if has_reply or mode == 'event':
        print_type(ri, 'reply')
2664
2665
def print_wrapped_type(ri):
    """
    Emit the wrapper struct around a reply object.

    Dump replies get a ``next`` pointer; notifications and events get the
    family / cmd / next / free bookkeeping members. The wrapped object
    itself is the 8-byte aligned ``obj`` member. The free prototype follows.
    """
    cw = ri.cw
    wrapper = type_name(ri, 'reply')

    cw.block_start(line=f"{wrapper}")
    if ri.op_mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        cw.p('__u16 family;')
        cw.p('__u8 cmd;')
        cw.p('struct ynl_ntf_base_type *next;')
        cw.p(f"void (*free)({wrapper} *ntf);")
    inner = type_name(ri, 'reply', deref=True)
    cw.p(f"{inner} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
2680
2681
def _free_type_members_iter(ri, struct):
    """Declare the ``i`` loop variable if freeing ``struct`` needs one."""
    if not struct.free_needs_iter():
        return
    ri.cw.p('unsigned int i;')
    ri.cw.nl()
2686
2687
def _free_type_members(ri, var, struct, ref=''):
    """Emit the free code of every member of ``struct``, in member order."""
    for _name, member in struct.member_list():
        member.free(ri, var, ref)
2691
2692
def _free_type(ri, direction, struct):
    """
    Emit the body of a ``_free`` function for ``struct``.

    Members are freed first; the object itself is only freed when a
    direction is given (directionless structs are not freed here).
    """
    arg_name = free_arg_name(direction)
    cw = ri.cw

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg_name, struct)
    if direction:
        cw.p(f'free({arg_name});')
    cw.block_end()
    cw.nl()
2704
2705
def free_rsp_nested_prototype(ri):
    """Emit the free prototype of a directionless (nested) type."""
    print_free_prototype(ri, "")
2708
2709
def free_rsp_nested(ri, struct):
    """Emit the free function of a directionless (nested) type."""
    no_direction = ""
    _free_type(ri, no_direction, struct)
2712
2713
def print_rsp_free(ri):
    """Emit the reply free function, if the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2718
2719
def print_dump_type_free(ri):
    """
    Emit the free function for a dump reply list.

    The generated C walks the list until YNL_LIST_END, freeing the members
    of each element's ``obj`` and then the element itself.
    """
    cw = ri.cw
    elem_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{elem_type} *next = rsp;")
    cw.nl()

    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, ri.struct['reply'])
    # Advance before freeing so the next pointer is read from live memory
    cw.p('rsp = next;')
    cw.p('next = rsp->next;')
    cw.nl()

    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()

    cw.block_end()
    cw.nl()
2738
2739
def print_ntf_type_free(ri):
    """
    Emit the free function for a notification: frees the members of the
    wrapped ``obj`` and then the notification itself.
    """
    cw = ri.cw
    reply_struct = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply_struct)
    _free_type_members(ri, 'rsp', reply_struct, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
2748
2749
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """
    Emit the first line of a kernel ``nla_policy`` array.

    With ``terminate`` an ``extern`` forward declaration is printed, unless
    the policy is static (per-op policy of a family whose policies are
    static), in which case nothing is printed. Without ``terminate`` the
    opening line of the definition (ending in `` = {``) is printed.
    """
    # Only per-op policies (ri given) can be static
    is_static = ri and policy_should_be_static(struct.family)

    if terminate:
        if is_static:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if is_static else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if ri:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name = f"{name}_{ri.op_mode}"
    else:
        name = struct.render_name
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2772
2773
def print_req_policy(cw, struct, ri=None):
    """
    Emit a complete kernel ``nla_policy`` array definition for ``struct``,
    wrapped in the op's 'config-cond' ifdef when an op is given.
    """
    has_op = ri and ri.op
    if has_op:
        cw.ifdef_block(ri.op.get('config-cond', None))

    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _name, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")

    cw.ifdef_block(None)
    cw.nl()
2783
2784
def kernel_can_gen_family_struct(family):
    """Return True if the family's protocol is 'genetlink'."""
    proto = family.proto
    return proto == 'genetlink'
2787
2788
def policy_should_be_static(family):
    """
    Return True if the family uses the 'split' kernel policy or if the
    family struct can be generated by this tool.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2791
2792
def print_kernel_policy_ranges(family, cw):
    """
    Emit ``netlink_range_validation`` structs for request attributes which
    carry a 'full-range' check. Attribute sets which are subsets of another
    set are skipped.
    """
    header_done = False
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if not header_done:
                cw.p('/* Integer value ranges */')
                header_done = True

            # Types starting with 'u' use the unsigned struct and ULL limits
            unsigned = attr.type[0] == 'u'
            sign = '' if unsigned else '_signed'
            suffix = 'ULL' if unsigned else 'LL'
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            members = [(bound, attr.get_limit_str(bound, suffix=suffix))
                       for bound in ('min', 'max') if bound in attr.checks]
            cw.write_struct_init(members)
            cw.block_end(line=';')
            cw.nl()
2820
2821
def print_kernel_policy_sparse_enum_validates(family, cw):
    """
    Emit validation callbacks for request attributes which carry a 'sparse'
    check and have an enum name.

    Each callback is a switch over the enum's entries returning 0 for a
    listed value, and -EINVAL with an extack message otherwise.
    """
    banner_pending = True
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not (attr.request and attr.enum_name and 'sparse' in attr.checks):
                continue

            if banner_pending:
                cw.p('/* Sparse enums validation callbacks */')
                banner_pending = False

            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
            cw.block_start()
            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
            enum = family.consts[attr['enum']]
            # All cases share one 'return 0', chained with explicit fallthrough
            for idx, entry in enumerate(enum.entries.values()):
                if idx > 0:
                    cw.p('fallthrough;')
                cw.p(f'case {entry.c_name}:')
            cw.p('return 0;')
            cw.block_end()
            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
            cw.p('return -EINVAL;')
            cw.block_end()
            cw.nl()
2858
2859
def print_kernel_op_table_fwd(family, cw, terminate):
    """
    Emit the declaration line of the kernel ops table.

    With ``terminate`` this is the header variant: an ``extern`` declaration
    of the table (only when the table is exported, i.e. the family struct
    is not generated), followed by prototypes for the pre/post hooks and
    the doit / dumpit handlers. Without ``terminate`` only the opening line
    of the table definition is emitted.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.ident_name} */")

        struct_type = {'global': 'genl_small_ops',
                       'per-op': 'genl_ops',
                       'split': 'genl_split_ops'}[family.kernel_policy]

        if not exported:
            # Static tables let the compiler size the array
            qual = 'static const'
            cnt = ""
        else:
            qual = 'const'
            if family.kernel_policy == 'split':
                # Split tables hold one entry per op mode
                cnt = sum(('do' in op) + ('dump' in op)
                          for op in family.ops.values())
            else:
                cnt = len(family.ops)

        decl = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {decl};")
        else:
            cw.block_start(line=decl + ' =')

    if not terminate:
        return

    cw.nl()
    do_args = ['const struct genl_split_ops *ops',
               'struct sk_buff *skb', 'struct genl_info *info']
    dump_args = ['struct netlink_callback *cb']
    hook_protos = (('pre', 'do', 'int', do_args),
                   ('post', 'do', 'void', do_args),
                   ('pre', 'dump', 'int', dump_args),
                   ('post', 'dump', 'int', dump_args))
    for when, mode, ret_type, args in hook_protos:
        for hook in family.hooks[when][mode]['list']:
            cw.write_func_prot(ret_type, c_lower(hook), list(args), suffix=';')

    cw.nl()

    handlers = (('do', 'doit', 'struct genl_info *info'),
                ('dump', 'dumpit', 'struct netlink_callback *cb'))
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        for mode, tail, last_arg in handlers:
            if mode not in op:
                continue
            handler = c_lower(f"{family.fn_prefix}-{op_name}-{tail}")
            cw.write_func_prot('int', handler,
                               ['struct sk_buff *skb', last_arg], suffix=';')
    cw.nl()
2925
2926
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side declarations for the kernel ops table."""
    print_kernel_op_table_fwd(family, cw, True)
2929
2930
def print_kernel_op_table(family, cw):
    """
    Emit the definition of the kernel ops table for the family.

    For 'global' and 'per-op' kernel policies one entry is written per
    non-async op; for the 'split' policy one entry is written per op mode
    ('do' / 'dump') of each non-async op. Entries are wrapped in the op's
    'config-cond' ifdef, if any.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.fn_prefix}-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): the policy is always built from the 'do'
                # request attributes; an op without a 'do' request would
                # raise KeyError here - confirm specs guarantee one.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Names of the pre / post callback members, per op mode
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the flags applicable to this op mode:
                    # 'dump' / 'dump-strict' are dropped for 'do' entries,
                    # 'strict' is dropped for 'dump' entries.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.fn_prefix}-{op_name}-{op_mode}it")
                # Members are appended in the order: pre hook, handler, post hook
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Ops with dual policies have one policy per op mode
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split entry additionally gets the cmd-cap flag of its mode
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    # Close any ifdef left open by the last entry
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
3009
3010
def print_kernel_mcgrp_hdr(family, cw):
    """
    Emit an anonymous enum with one ``<FAMILY>_NLGRP_<NAME>`` entry per
    multicast group; nothing is emitted when the family has no groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{grp['name']}") + ',')
    cw.block_end(';')
    cw.nl()
3021
3022
def print_kernel_mcgrp_src(family, cw):
    """
    Emit the ``genl_multicast_group`` array, indexed by the group enum
    values; nothing is emitted when the family has no groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        index = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{index}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
3034
3035
def print_kernel_family_struct_hdr(family, cw):
    """
    Emit the extern declaration of the family's ``genl_family`` struct and,
    when the family declares 'sock-priv', the prototypes of the socket
    private data init / destroy functions.
    """
    if kernel_can_gen_family_struct(family):
        cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
        cw.nl()
        if 'sock-priv' in family.kernel_family:
            priv_type = family.kernel_family["sock-priv"]
            for action in ('init', 'destroy'):
                cw.p(f'void {family.c_name}_nl_sock_priv_{action}({priv_type} *priv);')
            cw.nl()
3046
3047
def print_kernel_family_struct_src(family, cw):
    """
    Emit the definition of the family's ``struct genl_family``.

    Nothing is emitted unless the family struct can be generated (see
    kernel_can_gen_family_struct()). When the family declares 'sock-priv',
    ``void *`` trampolines for the socket private data init / destroy
    functions are emitted first and referenced from the struct.
    """
    if not kernel_can_gen_family_struct(family):
        return

    if 'sock-priv' in family.kernel_family:
        # Generate "trampolines" to make CFI happy
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
                      ["void *priv"])
        cw.nl()
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
                      ["void *priv"])
        cw.nl()

    # Name the struct with c_name so the definition always matches the
    # extern declaration printed by print_kernel_family_struct_hdr() and
    # the other symbols referenced below.
    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    if 'sock-priv' in family.kernel_family:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
3083
3084
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """
    Open a uAPI enum block, choosing the enum's tag from ``obj``.

    If ``obj`` has the ``enum_name`` key its value is the tag (a falsy
    value gives an anonymous enum). Otherwise, if ``ckey`` is set and
    present in ``obj``, the tag is the family name joined with that value.
    Without either key the enum is anonymous.
    """
    tag = ''
    if enum_name in obj:
        explicit = obj[enum_name]
        if explicit:
            tag = ' ' + c_lower(explicit)
    elif ckey and ckey in obj:
        tag = f' {family.c_name}_{c_lower(obj[ckey])}'
    cw.block_start(line='enum' + tag)
3093
3094
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """
    Emit the command enum for families using the unified message ID model.

    With ``separate_ntf`` notifications and events are left out. The max
    value is rendered either as the last enumerator or, with
    ``max_by_define``, as a #define following the enum.
    """
    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    # Value the C compiler would assign implicitly to the next enumerator
    expected = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        if op.value == expected:
            cw.p(f"{op.enum_name},")
        else:
            cw.p(f"{op.enum_name} = {op.value},")
            expected = op.value
        expected += 1
    cw.nl()

    if max_by_define:
        cw.p(cnt_name)
        cw.block_end(line=';')
        cw.p(f"#define {max_name} {max_value}")
    else:
        cw.p(cnt_name + ',')
        cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
    cw.nl()
3120
3121
def render_uapi_directional(family, cw, max_by_define):
    """
    Emit the two command enums of the directional message ID model.

    The first enum lists messages with a 'do' which are not events; the
    second lists replies of 'do' ops (with a ``_REPLY`` name suffix),
    notifications and events. Each enum starts with a ``..._NONE = 0``
    entry and ends with the count / max entries.
    """
    def emit_entry(label, op, expected):
        """Print one enumerator, return the next implicit value."""
        if op.value and op.value != expected:
            cw.p(f"{label} = {op.value},")
            expected = op.value
        else:
            cw.p(f"{label},")
        return expected + 1

    def close_enum(max_name, cnt_name):
        """Print the count / max entries and close the enum."""
        max_value = f"({cnt_name} - 1)"
        cw.nl()
        if max_by_define:
            cw.p(cnt_name)
            cw.block_end(line=';')
            cw.p(f"#define {max_name} {max_value}")
        else:
            cw.p(cnt_name + ',')
            cw.p(f"{max_name} = {max_value}")
            cw.block_end(line=';')
        cw.nl()

    # Messages sent by user space
    user_max = f"{family.op_prefix}USER_MAX"
    user_cnt = f"__{family.op_prefix}USER_CNT"

    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_USER_NONE = 0,'))
    val = 0
    for op in family.msgs.values():
        if 'do' not in op or 'event' in op:
            continue
        val = emit_entry(op.enum_name, op, val)
    close_enum(user_max, user_cnt)

    # Messages sent by the kernel
    kernel_max = f"{family.op_prefix}KERNEL_MAX"
    kernel_cnt = f"__{family.op_prefix}KERNEL_CNT"

    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_KERNEL_NONE = 0,'))
    val = 0
    for op in family.msgs.values():
        if not (('do' in op and 'reply' in op['do']) or 'notify' in op or 'event' in op):
            continue

        label = op.enum_name
        if 'event' not in op and 'notify' not in op:
            label = f'{label}_REPLY'
        val = emit_entry(label, op, val)
    close_enum(kernel_max, kernel_cnt)
3174
3175
def render_uapi(family, cw):
    """
    Render the complete uAPI header for the family.

    Sections are written in this order: include guard, family name and
    version defines, definitions (consts as defines; enums and flags as
    C enums with optional kdoc), attribute set enums, command enums,
    multicast group defines and the closing of the include guard.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive 'const' definitions are batched in `defines` and flushed
    # when a definition of another type is reached (or after the loop).
    defines = []
    for const in family['definitions']:
        # Definitions with a 'header' set are not rendered here
        if const.get('header'):
            continue

        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.header:
                continue

            if enum.has_doc():
                # Entry docs need a kdoc comment ('/**'), otherwise a plain
                # comment with just the enum's own doc is written.
                if enum.has_entry_doc():
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                # Only spell out the value where entry.value_change is set
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    # Flags get a mask of all values instead of a max
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    cnt_name = enum.enum_cnt_name
                    max_name = c_upper(name_pfx + 'max')
                    if not cnt_name:
                        cnt_name = '__' + name_pfx + 'max'
                    cw.p(c_upper(cnt_name) + ',')
                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            name_pfx = const.get('name-prefix', f"{family.ident_name}-")
            # NOTE(review): 'c-define-name' is looked up on the family, not
            # on the const being rendered - confirm this is intended.
            defines.append([c_upper(family.get('c-define-name',
                                               f"{name_pfx}{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    # Attribute sets; sets which are subsets of another set are skipped
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        # `val` tracks the value the next enumerator would get implicitly,
        # explicit values are only written where the attr value differs.
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        if attr_set.items():
            cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    # Notifications go into their own enum when 'async-prefix' is set
    separate_ntf = 'async-prefix' in family['operations']

    if family.msg_id_model == 'unified':
        render_uapi_unified(family, cw, max_by_define, separate_ntf)
    elif family.msg_id_model == 'directional':
        render_uapi_directional(family, cw, max_by_define)
    else:
        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            # Only notifications and events belong in this enum
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
3317
3318
def _render_user_ntf_entry(ri, op):
    """
    Emit one entry of the user space notification info table.

    For classic families the entry is indexed by the request which has the
    same value as the notification's response value, otherwise by the op
    itself.
    """
    cw = ri.cw

    if ri.family.is_classic():
        index_op = ri.family.req_by_value[op.rsp_value]
    else:
        index_op = op
    cw.block_start(line=f"[{index_op.enum_name}] = ")

    fields = [
        f".alloc_sz\t= sizeof({type_name(ri, 'event')}),",
        f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,",
        f".policy\t\t= &{ri.struct['reply'].render_name}_nest,",
        f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,",
    ]
    for field in fields:
        cw.p(field)
    cw.block_end(line=',')
3330
3331
def render_user_family(family, cw, prototype):
    """
    Emit the user space ``ynl_family`` description of the family.

    With ``prototype`` only the extern declaration is written. Otherwise
    the notification info table (if the family has notifications) and the
    family struct definition are written.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
        for ntf_name, ntf in family.ntfs.items():
            if 'notify' in ntf:
                # Notifications are rendered using the op they notify about
                source_op = family.ops[ntf['notify']]
                ri = RenderInfo(cw, family, "user", source_op, "notify")
            elif 'event' in ntf:
                ri = RenderInfo(cw, family, "user", ntf, "event")
            else:
                raise Exception('Invalid notification ' + ntf_name)
            _render_user_ntf_entry(ri, ntf)
        for _name, candidate in family.ops.items():
            if 'event' in candidate:
                ri = RenderInfo(cw, family, "user", candidate, "event")
                _render_user_ntf_entry(ri, candidate)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.is_classic():
        cw.p('.is_classic\t= true,')
        cw.p(f'.classic_id\t= {family.get("protonum")},')
        # Classic families have no genlmsghdr, only the fixed header if any
        if family.fixed_header:
            cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
    elif family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
    cw.block_end(line=';')
3373
3374
def family_contains_bitfield32(family):
    """
    Tell whether any attribute of the family has the type "bitfield32".

    Attribute sets which are a subset of another set are not inspected.
    """
    return any(
        attr.type == "bitfield32"
        for _, attr_set in family.attr_sets.items()
        if not attr_set.subset_of
        for _, attr in attr_set.items()
    )
3383
3384
def find_kernel_root(full_path):
    """
    Locate the directory holding MAINTAINERS above @full_path.

    Walks up from @full_path one path component at a time until a directory
    containing a file named MAINTAINERS is found.

    Returns a tuple of (root directory, path of @full_path relative to it).

    Raises FileNotFoundError if no ancestor directory contains MAINTAINERS.
    """
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # sub_path always ends with a separator left by the first join
            return parent, sub_path[:-1]
        # dirname() reaches a fixed point at the top of the path ('/' or '');
        # without this check the walk would spin forever
        if parent == full_path:
            raise FileNotFoundError(
                f"No MAINTAINERS file found in any directory above {sub_path[:-1]!r}")
        full_path = parent
3393
3394
def _main_parse_args():
    """Parse the command line; exits if neither --header nor --source is given."""
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True,
                        choices=('user', 'kernel', 'uapi'))
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    parser.add_argument('--function-prefix', dest='fn_prefix', type=str)
    args = parser.parse_args()

    # --header and --source share one destination which defaults to None,
    # so None means that neither of them was passed
    if args.header is None:
        parser.error("--header or --source is required")
    return args


def _main_load_family(args):
    """Load the spec named by @args; exits with status 1 on a YAML or license error."""
    required_license = '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)'
    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops, args.fn_prefix)
    except yaml.YAMLError as exc:
        # Diagnostics go to stderr, stdout may be the sink for generated code
        print(exc, file=sys.stderr)
        sys.exit(1)

    if parsed.license != required_license:
        print('Spec license:', parsed.license, file=sys.stderr)
        print('License must be: ' + required_license, file=sys.stderr)
        sys.exit(1)
    return parsed


def _main_write_banner(cw, args, parsed):
    """Write the license line and the "auto-generated" comment block."""
    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header or args.fn_prefix:
        # Record the extra arguments used for this run in the output
        line = ''.join(f' --user-header {one}' for one in args.user_header)
        line += ''.join(f' --exclude-op {one}' for one in args.exclude_op)
        if args.fn_prefix:
            line += f' --function-prefix {args.fn_prefix}'
        cw.p(f'/* YNL-ARG{line} */')
    cw.p('/* To regenerate run: tools/net/ynl/ynl-regen.sh */')
    cw.nl()


def _main_write_includes(cw, args, parsed):
    """Write the #include section for the 'kernel' and 'user' modes."""
    if args.out_file:
        # The companion header shares the base name of the output file
        hdr_file = os.path.splitext(os.path.basename(args.out_file))[0] + ".h"
    else:
        hdr_file = "generated_header_file.h"

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = []
    for definition in parsed['definitions'] + parsed['attribute-sets']:
        if 'header' in definition:
            headers.append(definition['header'])
    if args.mode == 'user':
        headers.append(parsed.uapi_header)
    # Include every header once, in order of first appearance
    for one in dict.fromkeys(headers):
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()


def _main_kernel_nested_policies(cw, parsed, emit):
    """Call @emit(cw, struct) for each pure nested struct used in requests."""
    nested = [struct for _, struct in sorted(parsed.pure_nested_structs.items())
              if struct.request]
    if nested:
        cw.p('/* Common nested types */')
    for struct in nested:
        emit(cw, struct)
    cw.nl()


def _main_kernel_header(cw, parsed, mode):
    """Write the body of a kernel header: policy and table declarations."""
    _main_kernel_nested_policies(cw, parsed, print_req_policy_fwd)

    if parsed.kernel_policy == 'global':
        cw.p(f"/* Global operation policy for {parsed.name} */")

        struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
        print_req_policy_fwd(cw, struct)
        cw.nl()

    if parsed.kernel_policy in {'per-op', 'split'}:
        for _, op in parsed.ops.items():
            if 'do' in op and 'event' not in op:
                ri = RenderInfo(cw, parsed, mode, op, "do")
                print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                cw.nl()

    print_kernel_op_table_hdr(parsed, cw)
    print_kernel_mcgrp_hdr(parsed, cw)
    print_kernel_family_struct_hdr(parsed, cw)


def _main_kernel_source(cw, parsed, mode):
    """Write the body of a kernel source file: policies and tables."""
    print_kernel_policy_ranges(parsed, cw)
    print_kernel_policy_sparse_enum_validates(parsed, cw)

    _main_kernel_nested_policies(cw, parsed, print_req_policy)

    if parsed.kernel_policy == 'global':
        cw.p(f"/* Global operation policy for {parsed.name} */")

        struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
        print_req_policy(cw, struct)
        cw.nl()

    if parsed.kernel_policy in {'per-op', 'split'}:
        for _, op in parsed.ops.items():
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    cw.p(f"/* {op.enum_name} - {op_mode} */")
                    ri = RenderInfo(cw, parsed, mode, op, op_mode)
                    print_req_policy(cw, ri.struct['request'], ri=ri)
                    cw.nl()

    print_kernel_op_table(parsed, cw)
    print_kernel_mcgrp_src(parsed, cw)
    print_kernel_family_struct_src(parsed, cw)


def _main_user_header(cw, parsed, mode):
    """Write the body of a user-space header: types and prototypes."""
    cw.p('/* Enums */')
    put_op_name_fwd(parsed, cw)

    for _, const in parsed.consts.items():
        if isinstance(const, EnumSet):
            put_enum_to_str_fwd(parsed, cw, const)
    cw.nl()

    cw.p('/* Common nested types */')
    for attr_set, struct in parsed.pure_nested_structs.items():
        ri = RenderInfo(cw, parsed, mode, "", "", attr_set)
        print_type_full(ri, struct)

    for _, op in parsed.ops.items():
        cw.p(f"/* ============== {op.enum_name} ============== */")

        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, "do")
            print_req_type(ri)
            print_req_type_helpers(ri)
            cw.nl()
            print_rsp_type(ri)
            print_rsp_type_helpers(ri)
            cw.nl()
            print_req_prototype(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, 'dump')
            print_req_type(ri)
            print_req_type_helpers(ri)
            if not ri.type_consistent or ri.type_oneside:
                print_rsp_type(ri)
            print_wrapped_type(ri)
            print_dump_prototype(ri)
            cw.nl()

        if op.has_ntf:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, 'notify')
            if not ri.type_consistent:
                raise Exception(f'Only notifications with consistent types supported ({op.name})')
            print_wrapped_type(ri)

    for _, op in parsed.ntfs.items():
        if 'event' in op:
            ri = RenderInfo(cw, parsed, mode, op, 'event')
            cw.p(f"/* {op.enum_name} - event */")
            print_rsp_type(ri)
            cw.nl()
            print_wrapped_type(ri)
    cw.nl()


def _main_user_source(cw, parsed, mode):
    """Write the body of a user-space source file: policies, parsers, requests."""
    cw.p('/* Enums */')
    put_op_name(parsed, cw)

    for _, const in parsed.consts.items():
        if isinstance(const, EnumSet):
            put_enum_to_str(parsed, cw, const)
    cw.nl()

    has_recursive_nests = False
    cw.p('/* Policies */')
    for struct in parsed.pure_nested_structs.values():
        if struct.recursive:
            put_typol_fwd(cw, struct)
            has_recursive_nests = True
    if has_recursive_nests:
        cw.nl()
    for struct in parsed.pure_nested_structs.values():
        put_typol(cw, struct)
    for name in parsed.root_sets:
        struct = Struct(parsed, name)
        put_typol(cw, struct)

    cw.p('/* Common nested types */')
    if has_recursive_nests:
        # Prototypes for all nested helpers are written ahead of the
        # definitions whenever any nested struct is recursive
        for attr_set, struct in parsed.pure_nested_structs.items():
            ri = RenderInfo(cw, parsed, mode, "", "", attr_set)
            free_rsp_nested_prototype(ri)
            if struct.request:
                put_req_nested_prototype(ri, struct)
            if struct.reply:
                parse_rsp_nested_prototype(ri, struct)
        cw.nl()
    for attr_set, struct in parsed.pure_nested_structs.items():
        ri = RenderInfo(cw, parsed, mode, "", "", attr_set)

        free_rsp_nested(ri, struct)
        if struct.request:
            put_req_nested(ri, struct)
        if struct.reply:
            parse_rsp_nested(ri, struct)

    for _, op in parsed.ops.items():
        cw.p(f"/* ============== {op.enum_name} ============== */")
        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, "do")
            print_req_free(ri)
            print_rsp_free(ri)
            parse_rsp_msg(ri)
            print_req(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, "dump")
            if not ri.type_consistent or ri.type_oneside:
                parse_rsp_msg(ri, deref=True)
            print_req_free(ri)
            print_dump_type_free(ri)
            print_dump(ri)
            cw.nl()

        if op.has_ntf:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, 'notify')
            if not ri.type_consistent:
                raise Exception(f'Only notifications with consistent types supported ({op.name})')
            print_ntf_type_free(ri)

    for _, op in parsed.ntfs.items():
        if 'event' in op:
            cw.p(f"/* {op.enum_name} - event */")

            ri = RenderInfo(cw, parsed, mode, op, "do")
            parse_rsp_msg(ri)

            ri = RenderInfo(cw, parsed, mode, op, "event")
            print_ntf_type_free(ri)
    cw.nl()
    render_user_family(parsed, cw, False)


def main():
    """
    Entry point of the generator.

    Parses the command line, loads the spec and writes the requested output
    (uAPI header, or the header / source of the kernel or user-space code)
    through a CodeWriter, to the file given with -o if any.

    Exits with a non-zero status on invalid arguments, on a YAML error
    while loading the spec, or if the spec carries an unexpected license.
    """
    args = _main_parse_args()
    parsed = _main_load_family(args)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))
    _main_write_banner(cw, args, parsed)

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    _main_write_includes(cw, args, parsed)

    if args.mode == "kernel":
        if args.header:
            _main_kernel_header(cw, parsed, args.mode)
        else:
            _main_kernel_source(cw, parsed, args.mode)
    elif args.mode == "user":
        if args.header:
            _main_user_header(cw, parsed, args.mode)
        else:
            _main_user_source(cw, parsed, args.mode)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
3709
3710
# Run the generator only when this file is executed as a script
if __name__ == "__main__":
    main()
3713