xref: /linux/tools/net/ynl/lib/ynl.py (revision f2ec98566775dd4341ec1dcf93aa5859c60de826)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4import functools
5import os
6import random
7import socket
8import struct
9from struct import Struct
10import yaml
11import ipaddress
12import uuid
13
14from .nlspec import SpecFamily
15
16#
17# Generic Netlink code which should really be in some library, but I can't quickly find one.
18#
19
20
class Netlink:
    """Namespace for the numeric netlink / generic netlink constants used below."""

    # --- socket level and socket options ---
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # --- message types ---
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # --- message flags: common ---
    NLM_F_REQUEST = 0x1
    NLM_F_ACK = 0x4

    # --- message flags: get requests ---
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # --- message flags: new requests (values overlap with the get flags) ---
    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    # --- message flags: acks (values overlap with the get flags) ---
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    # --- attribute type flags, stored in the top bits of the type field ---
    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # --- generic netlink ---
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # --- nlctrl family: commands and attributes ---
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # --- extended ack attribute types ---
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6
77
78
class NlError(Exception):
    """Exception wrapping a netlink message that carries an error code.

    The message is available as ``nl_msg``; its ``error`` attribute is
    negated and passed to os.strerror() when the exception is rendered.
    """

    def __init__(self, nl_msg):
        # Initialise the Exception base so that ``args`` is populated
        super().__init__(nl_msg)
        self.nl_msg = nl_msg

    def __str__(self):
        return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}"
85
86
class NlAttr:
    """A single netlink attribute parsed from a byte buffer.

    The 4-byte attribute header (length, type) is decoded in native byte
    order and the payload is exposed as ``raw``.  The ``as_*`` helpers
    interpret the payload as a scalar, string, byte string or C array.
    """

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    # Pre-compiled struct formats for each scalar type, per byte order
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute that starts at ``offset`` inside ``raw``.

        Raises an Exception when the length field is smaller than the
        attribute header itself.
        """
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # The length field counts the 4-byte header.  Anything smaller is
        # malformed and would give full_len == 0, so a caller stepping by
        # full_len would never advance.
        if self._len < 4:
            raise Exception(f"Invalid attribute length {self._len} at offset {offset}, "
                            "must be at least 4")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        # NOTE: despite the name this is the length *including* the header
        self.payload_len = self._len
        # Length rounded up to the next multiple of 4
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for ``attr_type``.

        ``byte_order`` of "big-endian" selects the big-endian format, any
        other truthy value the little-endian one, and a falsy value the
        native one.  Raises KeyError for an unknown ``attr_type``.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" \
                else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the payload as one scalar of type ``attr_type``."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a variable-width scalar whose width is given by the payload size.

        The first letter of ``attr_type`` selects signedness; the payload
        must be exactly 4 or 8 bytes long.
        """
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar payload length must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII and drop its last character (the terminator)."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes unchanged."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as a list of native-order scalars of ``type``."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
139
140
class NlAttrs:
    """Ordered collection of the NlAttr objects found back-to-back in a buffer."""

    def __init__(self, msg, offset=0):
        """Parse attributes from ``msg`` starting at ``offset`` until the buffer ends."""
        parsed = []
        pos = offset
        end = len(msg)
        while pos < end:
            nla = NlAttr(msg, pos)
            # step by the padded length to land on the next attribute
            pos += nla.full_len
            parsed.append(nla)
        self.attrs = parsed

    def __iter__(self):
        for nla in self.attrs:
            yield nla

    def __repr__(self):
        # one attribute per line, no trailing newline
        return '\n'.join(repr(nla) for nla in self.attrs)
159        return msg
160
161
class NlMsg:
    """One netlink message parsed from a receive buffer.

    The 16-byte header is decoded into ``nl_len``, ``nl_type``,
    ``nl_flags``, ``nl_seq`` and ``nl_portid``; the remainder of the
    message is kept in ``raw``.  For NLMSG_ERROR and NLMSG_DONE messages
    ``done`` is set, and for NLMSG_ERROR the error code is stored in
    ``error``.  Extended-ack attributes, when flagged, are decoded into the
    ``extack`` dict.
    """

    def __init__(self, msg, offset, attr_space=None):
        """Parse the message starting at ``offset`` in ``msg``.

        ``attr_space``, when given, is used to translate a top-level
        'miss-type' extack value into the attribute's name (and doc).
        Raises an Exception when the length field is smaller than the
        message header.
        """
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)
        # The length field counts the 16-byte header.  A smaller value is
        # malformed and would keep a caller stepping by nl_len from ever
        # advancing through the buffer.
        if self.nl_len < 16:
            raise Exception(f"Invalid netlink message length {self.nl_len} at offset {offset}, "
                            "must be at least 16")

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        # Offset of the extack attributes within self.raw, if there can be any
        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # skip the 4-byte error code and the 16-byte echoed request header
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                else:
                    # keep attributes we do not interpret, undecoded
                    self.extack.setdefault('unknown', []).append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        desc = spec['name']
                        if 'doc' in spec:
                            desc += f" ({spec['doc']})"
                        self.extack['miss-type'] = desc

    def cmd(self):
        """Return the netlink message type."""
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
        if self.error:
            msg += '\terror: ' + str(self.error)
        if self.extack:
            msg += '\textack: ' + repr(self.extack)
        return msg
222
223
class NlMsgs:
    """All netlink messages contained in one receive buffer."""

    def __init__(self, data, attr_space=None):
        """Parse ``data`` into NlMsg objects; ``attr_space`` is passed on to each."""
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            # Consecutive messages start on 4-byte boundaries, so step by the
            # aligned length (NLMSG_ALIGN) rather than the raw nl_len.
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
235        yield from self.msgs
236
237
238genl_family_name_to_id = None
239
240
def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build a generic netlink request header, without the leading length field.

    The length is prepended afterwards by _genl_msg_finalize().  When no
    ``seq`` is given a random sequence number between 1 and 1024 is used.
    """
    seq_no = random.randint(1, 1024) if seq is None else seq
    nl_hdr = struct.pack("HHII", nl_type, nl_flags, seq_no, 0)
    genl_hdr = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nl_hdr + genl_hdr
247    return nlmsg + genlmsg
248
249
def _genl_msg_finalize(msg):
    """Prepend the 32-bit length field, which counts itself plus ``msg``."""
    total_len = 4 + len(msg)
    return struct.pack("I", total_len) + msg
252
253
def _genl_load_families():
    """Dump the generic netlink families and cache them by name.

    Fills the module-level ``genl_family_name_to_id`` dict with one entry
    per family that reported both a name and an id.  Each entry may hold
    'id', 'name', 'maxattr' and 'mcast' (group name -> group id).  On a
    netlink error the error code is printed and loading stops, leaving
    whatever was collected so far.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        req_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP
        request = _genl_msg(Netlink.GENL_ID_CTRL, req_flags,
                            Netlink.CTRL_CMD_GETFAMILY, 1)
        sock.send(_genl_msg_finalize(request), 0)

        genl_family_name_to_id = dict()

        # Keep receiving until the dump is terminated by a done or error message
        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                family = {}
                for attr in NlAttrs(GenlMsg(nl_msg).raw):
                    kind = attr.type
                    if kind == Netlink.CTRL_ATTR_FAMILY_ID:
                        family['id'] = attr.as_scalar('u16')
                    elif kind == Netlink.CTRL_ATTR_FAMILY_NAME:
                        family['name'] = attr.as_strz()
                    elif kind == Netlink.CTRL_ATTR_MAXATTR:
                        family['maxattr'] = attr.as_scalar('u32')
                    elif kind == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        groups = {}
                        family['mcast'] = groups
                        # each entry is a nest holding the group's name and id
                        for entry in NlAttrs(attr.raw):
                            grp_name = None
                            grp_id = None
                            for field in NlAttrs(entry.raw):
                                if field.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    grp_name = field.as_strz()
                                elif field.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    grp_id = field.as_scalar('u32')
                            if grp_name and grp_id is not None:
                                groups[grp_name] = grp_id
                if 'name' in family and 'id' in family:
                    genl_family_name_to_id[family['name']] = family
300                    genl_family_name_to_id[fam['name']] = fam
301
302
class GenlMsg:
    """Generic netlink view of a netlink message.

    Splits the 4-byte generic netlink header (command, version, reserved)
    off the start of the wrapped message's payload.
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        # everything after the generic netlink header
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the generic netlink command."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is only attached by NetlinkProtocol.decode(); a message
        # that has not been through it is printed without attributes
        for a in getattr(self, 'raw_attrs', []):
            msg += '\t\t' + repr(a) + '\n'
        return msg
317        return msg
318
319
class NetlinkProtocol:
    """Message framing for a plain netlink family: build requests, decode replies."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack type, flags, sequence number and a zero port id (no length field)."""
        seq_no = seq if seq is not None else random.randint(1, 1024)
        return struct.pack("HHII", nl_type, nl_flags, seq_no, 0)

    def message(self, flags, command, version, seq=None):
        """Build a request header; the command is the message type, ``version`` is unused."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        """Return the message object whose payload holds the attributes."""
        return nl_msg

    def decode(self, ynl, nl_msg):
        """Attach the parsed attributes to the message as ``raw_attrs`` and return it.

        With ``ynl`` given, the fixed header of the matching response
        operation is skipped before attribute parsing starts.
        """
        msg = self._decode(nl_msg)
        if ynl:
            rsp_op = ynl.rsp_by_value[msg.cmd()]
            hdr_size = ynl._struct_size(rsp_op.fixed_header)
        else:
            hdr_size = 0
        msg.raw_attrs = NlAttrs(msg.raw, hdr_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the ``value`` of the named group from the spec's multicast groups."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
349        return mcast_groups[mcast_name].value
350
351
352class GenlProtocol(NetlinkProtocol):
353    def __init__(self, family_name):
354        super().__init__(family_name, Netlink.NETLINK_GENERIC)
355
356        global genl_family_name_to_id
357        if genl_family_name_to_id is None:
358            _genl_load_families()
359
360        self.genl_family = genl_family_name_to_id[family_name]
361        self.family_id = genl_family_name_to_id[family_name]['id']
362
363    def message(self, flags, command, version, seq=None):
364        nlmsg = self._message(self.family_id, flags, seq)
365        genlmsg = struct.pack("BBH", command, version, 0)
366        return nlmsg + genlmsg
367
368    def _decode(self, nl_msg):
369        return GenlMsg(nl_msg)
370
371    def get_mcast_id(self, mcast_name, mcast_groups):
372        if mcast_name not in self.genl_family['mcast']:
373            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
374        return self.genl_family['mcast'][mcast_name]
375
376
377
class SpaceAttrs:
    """Chain of (attribute space, values) scopes, searched innermost first."""

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        """Create a scope for ``attr_space``/``attrs`` in front of ``outer``'s scopes."""
        enclosing = outer.scopes if outer else []
        self.scopes = [self.SpecValuesPair(attr_space, attrs), *enclosing]

    def lookup(self, name):
        """Return the value of ``name`` from the first scope whose spec defines it.

        Raises an Exception when that scope holds no value for it, or when
        no scope defines the attribute at all.
        """
        for spec, values in self.scopes:
            if name not in spec:
                continue
            if name in values:
                return values[name]
            # defined here but not set: outer scopes are not consulted
            spec_name = spec.yaml['name']
            raise Exception(
                f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")
394        raise Exception(f"Attribute '{name}' not defined in any attribute-set")
395
396
397#
398# YNL implementation details.
399#
400
401
402class YnlFamily(SpecFamily):
403    def __init__(self, def_path, schema=None, process_unknown=False):
404        super().__init__(def_path, schema)
405
406        self.include_raw = False
407        self.process_unknown = process_unknown
408
409        try:
410            if self.proto == "netlink-raw":
411                self.nlproto = NetlinkProtocol(self.yaml['name'],
412                                               self.yaml['protonum'])
413            else:
414                self.nlproto = GenlProtocol(self.yaml['name'])
415        except KeyError:
416            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
417
418        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
419        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
420        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
421        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
422
423        self.async_msg_ids = set()
424        self.async_msg_queue = []
425
426        for msg in self.msgs.values():
427            if msg.is_async:
428                self.async_msg_ids.add(msg.rsp_value)
429
430        for op_name, op in self.ops.items():
431            bound_f = functools.partial(self._op, op_name)
432            setattr(self, op.ident_name, bound_f)
433
434
435    def ntf_subscribe(self, mcast_name):
436        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
437        self.sock.bind((0, 0))
438        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
439                             mcast_id)
440
441    def _add_attr(self, space, name, value, search_attrs):
442        try:
443            attr = self.attr_sets[space][name]
444        except KeyError:
445            raise Exception(f"Space '{space}' has no attribute '{name}'")
446        nl_type = attr.value
447        if attr["type"] == 'nest':
448            nl_type |= Netlink.NLA_F_NESTED
449            attr_payload = b''
450            sub_attrs = SpaceAttrs(self.attr_sets[space], value, search_attrs)
451            for subname, subvalue in value.items():
452                attr_payload += self._add_attr(attr['nested-attributes'],
453                                               subname, subvalue, sub_attrs)
454        elif attr["type"] == 'flag':
455            attr_payload = b''
456        elif attr["type"] == 'string':
457            attr_payload = str(value).encode('ascii') + b'\x00'
458        elif attr["type"] == 'binary':
459            if isinstance(value, bytes):
460                attr_payload = value
461            elif isinstance(value, str):
462                attr_payload = bytes.fromhex(value)
463            elif isinstance(value, dict) and attr.struct_name:
464                attr_payload = self._encode_struct(attr.struct_name, value)
465            else:
466                raise Exception(f'Unknown type for binary attribute, value: {value}')
467        elif attr.is_auto_scalar:
468            scalar = int(value)
469            real_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
470            format = NlAttr.get_format(real_type, attr.byte_order)
471            attr_payload = format.pack(int(value))
472        elif attr['type'] in NlAttr.type_formats:
473            format = NlAttr.get_format(attr['type'], attr.byte_order)
474            attr_payload = format.pack(int(value))
475        elif attr['type'] in "bitfield32":
476            attr_payload = struct.pack("II", int(value["value"]), int(value["selector"]))
477        elif attr['type'] == 'sub-message':
478            msg_format = self._resolve_selector(attr, search_attrs)
479            attr_payload = b''
480            if msg_format.fixed_header:
481                attr_payload += self._encode_struct(msg_format.fixed_header, value)
482            if msg_format.attr_set:
483                if msg_format.attr_set in self.attr_sets:
484                    nl_type |= Netlink.NLA_F_NESTED
485                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
486                    for subname, subvalue in value.items():
487                        attr_payload += self._add_attr(msg_format.attr_set,
488                                                       subname, subvalue, sub_attrs)
489                else:
490                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
491        else:
492            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
493
494        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
495        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
496
497    def _decode_enum(self, raw, attr_spec):
498        enum = self.consts[attr_spec['enum']]
499        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
500            i = 0
501            value = set()
502            while raw:
503                if raw & 1:
504                    value.add(enum.entries_by_val[i].name)
505                raw >>= 1
506                i += 1
507        else:
508            value = enum.entries_by_val[raw].name
509        return value
510
511    def _decode_binary(self, attr, attr_spec):
512        if attr_spec.struct_name:
513            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
514        elif attr_spec.sub_type:
515            decoded = attr.as_c_array(attr_spec.sub_type)
516        else:
517            decoded = attr.as_bin()
518            if attr_spec.display_hint:
519                decoded = self._formatted_string(decoded, attr_spec.display_hint)
520        return decoded
521
522    def _decode_array_nest(self, attr, attr_spec):
523        decoded = []
524        offset = 0
525        while offset < len(attr.raw):
526            item = NlAttr(attr.raw, offset)
527            offset += item.full_len
528
529            subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
530            decoded.append({ item.type: subattrs })
531        return decoded
532
533    def _decode_unknown(self, attr):
534        if attr.is_nest:
535            return self._decode(NlAttrs(attr.raw), None)
536        else:
537            return attr.as_bin()
538
539    def _rsp_add(self, rsp, name, is_multi, decoded):
540        if is_multi == None:
541            if name in rsp and type(rsp[name]) is not list:
542                rsp[name] = [rsp[name]]
543                is_multi = True
544            else:
545                is_multi = False
546
547        if not is_multi:
548            rsp[name] = decoded
549        elif name in rsp:
550            rsp[name].append(decoded)
551        else:
552            rsp[name] = [decoded]
553
554    def _resolve_selector(self, attr_spec, search_attrs):
555        sub_msg = attr_spec.sub_message
556        if sub_msg not in self.sub_msgs:
557            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
558        sub_msg_spec = self.sub_msgs[sub_msg]
559
560        selector = attr_spec.selector
561        value = search_attrs.lookup(selector)
562        if value not in sub_msg_spec.formats:
563            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
564
565        spec = sub_msg_spec.formats[value]
566        return spec
567
568    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
569        msg_format = self._resolve_selector(attr_spec, search_attrs)
570        decoded = {}
571        offset = 0
572        if msg_format.fixed_header:
573            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header));
574            offset = self._struct_size(msg_format.fixed_header)
575        if msg_format.attr_set:
576            if msg_format.attr_set in self.attr_sets:
577                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
578                decoded.update(subdict)
579            else:
580                raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")
581        return decoded
582
583    def _decode(self, attrs, space, outer_attrs = None):
584        if space:
585            attr_space = self.attr_sets[space]
586        rsp = dict()
587        search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
588
589        for attr in attrs:
590            try:
591                attr_spec = attr_space.attrs_by_val[attr.type]
592            except (KeyError, UnboundLocalError):
593                if not self.process_unknown:
594                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
595                attr_name = f"UnknownAttr({attr.type})"
596                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
597                continue
598
599            if attr_spec["type"] == 'nest':
600                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
601                decoded = subdict
602            elif attr_spec["type"] == 'string':
603                decoded = attr.as_strz()
604            elif attr_spec["type"] == 'binary':
605                decoded = self._decode_binary(attr, attr_spec)
606            elif attr_spec["type"] == 'flag':
607                decoded = True
608            elif attr_spec.is_auto_scalar:
609                decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
610            elif attr_spec["type"] in NlAttr.type_formats:
611                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
612                if 'enum' in attr_spec:
613                    decoded = self._decode_enum(decoded, attr_spec)
614            elif attr_spec["type"] == 'array-nest':
615                decoded = self._decode_array_nest(attr, attr_spec)
616            elif attr_spec["type"] == 'bitfield32':
617                value, selector = struct.unpack("II", attr.raw)
618                if 'enum' in attr_spec:
619                    value = self._decode_enum(value, attr_spec)
620                    selector = self._decode_enum(selector, attr_spec)
621                decoded = {"value": value, "selector": selector}
622            elif attr_spec["type"] == 'sub-message':
623                decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
624            else:
625                if not self.process_unknown:
626                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
627                decoded = self._decode_unknown(attr)
628
629            self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
630
631        return rsp
632
633    def _decode_extack_path(self, attrs, attr_set, offset, target):
634        for attr in attrs:
635            try:
636                attr_spec = attr_set.attrs_by_val[attr.type]
637            except KeyError:
638                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
639            if offset > target:
640                break
641            if offset == target:
642                return '.' + attr_spec.name
643
644            if offset + attr.full_len <= target:
645                offset += attr.full_len
646                continue
647            if attr_spec['type'] != 'nest':
648                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
649            offset += 4
650            subpath = self._decode_extack_path(NlAttrs(attr.raw),
651                                               self.attr_sets[attr_spec['nested-attributes']],
652                                               offset, target)
653            if subpath is None:
654                return None
655            return '.' + attr_spec.name + subpath
656
657        return None
658
659    def _decode_extack(self, request, op, extack):
660        if 'bad-attr-offs' not in extack:
661            return
662
663        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set))
664        offset = 20 + self._struct_size(op.fixed_header)
665        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
666                                        extack['bad-attr-offs'])
667        if path:
668            del extack['bad-attr-offs']
669            extack['bad-attr'] = path
670
671    def _struct_size(self, name):
672        if name:
673            members = self.consts[name].members
674            size = 0
675            for m in members:
676                if m.type in ['pad', 'binary']:
677                    if m.struct:
678                        size += self._struct_size(m.struct)
679                    else:
680                        size += m.len
681                else:
682                    format = NlAttr.get_format(m.type, m.byte_order)
683                    size += format.size
684            return size
685        else:
686            return 0
687
688    def _decode_struct(self, data, name):
689        members = self.consts[name].members
690        attrs = dict()
691        offset = 0
692        for m in members:
693            value = None
694            if m.type == 'pad':
695                offset += m.len
696            elif m.type == 'binary':
697                if m.struct:
698                    len = self._struct_size(m.struct)
699                    value = self._decode_struct(data[offset : offset + len],
700                                                m.struct)
701                    offset += len
702                else:
703                    value = data[offset : offset + m.len]
704                    offset += m.len
705            else:
706                format = NlAttr.get_format(m.type, m.byte_order)
707                [ value ] = format.unpack_from(data, offset)
708                offset += format.size
709            if value is not None:
710                if m.enum:
711                    value = self._decode_enum(value, m)
712                elif m.display_hint:
713                    value = self._formatted_string(value, m.display_hint)
714                attrs[m.name] = value
715        return attrs
716
717    def _encode_struct(self, name, vals):
718        members = self.consts[name].members
719        attr_payload = b''
720        for m in members:
721            value = vals.pop(m.name) if m.name in vals else None
722            if m.type == 'pad':
723                attr_payload += bytearray(m.len)
724            elif m.type == 'binary':
725                if m.struct:
726                    if value is None:
727                        value = dict()
728                    attr_payload += self._encode_struct(m.struct, value)
729                else:
730                    if value is None:
731                        attr_payload += bytearray(m.len)
732                    else:
733                        attr_payload += bytes.fromhex(value)
734            else:
735                if value is None:
736                    value = 0
737                format = NlAttr.get_format(m.type, m.byte_order)
738                attr_payload += format.pack(value)
739        return attr_payload
740
741    def _formatted_string(self, raw, display_hint):
742        if display_hint == 'mac':
743            formatted = ':'.join('%02x' % b for b in raw)
744        elif display_hint == 'hex':
745            formatted = bytes.hex(raw, ' ')
746        elif display_hint in [ 'ipv4', 'ipv6' ]:
747            formatted = format(ipaddress.ip_address(raw))
748        elif display_hint == 'uuid':
749            formatted = str(uuid.UUID(bytes=raw))
750        else:
751            formatted = raw
752        return formatted
753
754    def handle_ntf(self, decoded):
755        msg = dict()
756        if self.include_raw:
757            msg['raw'] = decoded
758        op = self.rsp_by_value[decoded.cmd()]
759        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
760        if op.fixed_header:
761            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
762
763        msg['name'] = op['name']
764        msg['msg'] = attrs
765        self.async_msg_queue.append(msg)
766
767    def check_ntf(self):
768        while True:
769            try:
770                reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
771            except BlockingIOError:
772                return
773
774            nms = NlMsgs(reply)
775            for nl_msg in nms:
776                if nl_msg.error:
777                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
778                    print(nl_msg)
779                    continue
780                if nl_msg.done:
781                    print("Netlink done while checking for ntf!?")
782                    continue
783
784                decoded = self.nlproto.decode(self, nl_msg)
785                if decoded.cmd() not in self.async_msg_ids:
786                    print("Unexpected msg id done while checking for ntf", decoded)
787                    continue
788
789                self.handle_ntf(decoded)
790
791    def operation_do_attributes(self, name):
792      """
793      For a given operation name, find and return a supported
794      set of attributes (as a dict).
795      """
796      op = self.find_operation(name)
797      if not op:
798        return None
799
800      return op['do']['request']['attributes'].copy()
801
    def _op(self, method, vals, flags=None, dump=False):
        """Send the request for operation ``method`` and collect its replies.

        ``vals`` maps names to values.  Members of the operation's fixed
        header are encoded first (and removed from ``vals`` by
        _encode_struct()), the remaining entries are encoded as attributes.
        ``flags`` is an iterable of extra netlink flags; ``dump`` adds
        NLM_F_DUMP.

        Returns None when no reply was decoded, the single reply dict for a
        non-dump request with exactly one reply, and a list of reply dicts
        otherwise.  Raises NlError when a message carries an error code.
        """
        op = self.ops[method]

        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag
        if dump:
            nl_flags |= Netlink.NLM_F_DUMP

        # Random sequence number, used below to match replies to this request
        req_seq = random.randint(1024, 65535)
        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            msg += self._encode_struct(op.fixed_header, vals)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
        # prepend the total length to complete the netlink header
        msg = _genl_msg_finalize(msg)

        self.sock.send(msg, 0)

        done = False
        rsp = []
        # Keep receiving until a message marked done (ack or end of dump) shows up
        while not done:
            reply = self.sock.recv(128 * 1024)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            for nl_msg in nms:
                if nl_msg.extack:
                    # turn a reported attribute offset into a readable path
                    self._decode_extack(msg, op, nl_msg.extack)

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    # done without an error code but with extack: only a warning
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)
                    done = True
                    break

                decoded = self.nlproto.decode(self, nl_msg)

                # Check if this is a reply to our request
                if nl_msg.nl_seq != req_seq or decoded.cmd() != op.rsp_value:
                    # not ours: queue it if it is a known notification, else report it
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                rsp.append(rsp_msg)

        if not rsp:
            return None
        if not dump and len(rsp) == 1:
            return rsp[0]
        return rsp
861
862    def do(self, method, vals, flags=None):
863        return self._op(method, vals, flags)
864
865    def dump(self, method, vals):
866        return self._op(method, vals, [], dump=True)
867