xref: /linux/tools/net/ynl/pyynl/lib/ynl.py (revision fcab107abe1ab5be9dbe874baa722372da8f4f73)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4from enum import Enum
5import functools
6import os
7import random
8import socket
9import struct
10from struct import Struct
11import sys
12import yaml
13import ipaddress
14import uuid
15import queue
16import selectors
17import time
18
19from .nlspec import SpecFamily
20
21#
22# Generic Netlink code which should really be in some library, but I can't quickly find one.
23#
24
25
class Netlink:
    """Netlink, generic netlink and nlctrl constants used by this module."""

    # --- Socket options, set at level SOL_NETLINK ---
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # --- Control message types (nlmsg_type) ---
    NLMSG_ERROR = 0x2
    NLMSG_DONE = 0x3

    # --- nlmsg_flags ---
    NLM_F_REQUEST = 0x001
    NLM_F_ACK = 0x004

    # Modifiers for GET requests; NLM_F_DUMP combines both.
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    # Modifiers for NEW requests (they reuse the GET modifier bit values).
    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    # Modifiers found on acknowledgements (NLMSG_ERROR).
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # --- Flag bits carried in the attribute header's nla_type ---
    NLA_F_NESTED = 1 << 15
    NLA_F_NET_BYTEORDER = 1 << 14

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # --- Generic netlink ---
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl command and top-level attributes
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    # attributes of each entry nested in CTRL_ATTR_MCAST_GROUPS
    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # --- Extended ACK attribute types ---
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # --- Policy description attribute types ---
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Attribute policy types as reported in NL_POLICY_TYPE_ATTR_TYPE.  The
    # functional Enum API numbers members from 1 in the order listed.
    AttrType = Enum('AttrType', 'flag u8 u16 u32 u64 s8 s16 s32 s64 '
                                'binary string nul-string nested nested-array '
                                'bitfield32 sint uint')
102
class NlError(Exception):
    """Exception raised for a netlink error reply.

    Attributes:
        nl_msg: the message object that carried the error.
        error: positive errno value (the negation of nl_msg.error, since
            netlink reports errors as negative errno codes).
    """

    def __init__(self, nl_msg):
        # Hand nl_msg to Exception so e.args is populated; pickle and copy
        # rebuild exceptions by calling cls(*args), which otherwise fails
        # with a missing-argument TypeError.
        super().__init__(nl_msg)
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"
110
111
class ConfigError(Exception):
    """Raised when YnlFamily is constructed with an unusable setting."""
114
115
class NlAttr:
    """A single netlink attribute (TLV) located at *offset* in *raw*.

    Attributes:
        type: attribute type with the NLA_F_NESTED / NLA_F_NET_BYTEORDER
            bits masked off.
        is_nest: non-zero when NLA_F_NESTED was set in the header.
        payload_len: nla_len from the header (it includes the 4-byte header).
        full_len: nla_len rounded up to a multiple of 4, i.e. the distance
            from this attribute to the next one.
        raw: payload bytes, without the header and trailing padding.

    Raises:
        Exception: the header's nla_len is smaller than the header itself.
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # nla_len covers the 4-byte header, so anything smaller is corrupt.
        # Rejecting it here also stops loops that advance by full_len from
        # spinning forever on a zero length.
        if self._len < 4:
            raise Exception(f"Invalid netlink attribute length {self._len} at offset {offset}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for *attr_type*.

        Native byte order is used when *byte_order* is empty; otherwise
        "big-endian" selects big-endian and any other value little-endian.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" \
                else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as one fixed-width scalar of *attr_type*."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable-width 'sint'/'uint' payload (4 or 8 bytes)."""
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar len payload must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as an ASCII string, dropping a trailing NUL."""
        # NUL-terminated strings end in a NUL, but NLA_STRING payloads are
        # not required to; only strip the terminator when it is present so
        # the last real character is never lost.
        payload = self.raw[:-1] if self.raw.endswith(b'\x00') else self.raw
        return payload.decode('ascii')

    def as_bin(self):
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as a packed array of native-order scalars."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
168
169
class NlAttrs:
    """Back-to-back netlink attributes parsed out of a buffer.

    Args:
        msg: buffer holding the attributes.
        offset: position of the first attribute.  Parsing continues to the
            end of *msg*, stepping by each attribute's padded length.
    """

    def __init__(self, msg, offset=0):
        self.attrs = []
        pos = offset
        end = len(msg)
        while pos < end:
            current = NlAttr(msg, pos)
            self.attrs.append(current)
            pos += current.full_len

    def __iter__(self):
        return iter(self.attrs)

    def __repr__(self):
        # One attribute per line.
        return '\n'.join(repr(a) for a in self.attrs)
189
190
class NlMsg:
    """One netlink message parsed out of a receive buffer.

    Args:
        msg: buffer holding one or more netlink messages.
        offset: start of this message's nlmsghdr within *msg*.
        attr_space: optional attribute-set spec.  When given, a top-level
            extack 'miss-type' value is translated to the attribute's name
            (and 'miss-type-doc' is added if the spec has a doc).

    Attributes:
        nl_len, nl_type, nl_flags, nl_seq, nl_portid: nlmsghdr fields.
        raw: payload following the 16-byte header.
        error: error code from an NLMSG_ERROR / NLMSG_DONE payload, else 0.
        done: 1 for NLMSG_ERROR and NLMSG_DONE messages, else 0.
        extack: dict of decoded extended-ACK TLVs, or None.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = self._error_tlv_offset()
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        self.extack['miss-type'] = spec['name']
                        if 'doc' in spec:
                            self.extack['miss-type-doc'] = spec['doc']

    def _error_tlv_offset(self):
        """Return where extack TLVs start in an NLMSG_ERROR payload.

        struct nlmsgerr is the 4-byte error code followed by the request's
        16-byte nlmsghdr.  Unless the ack is flagged NLM_F_CAPPED, the
        request payload is echoed as well, so the TLVs only start after the
        whole original message, rounded up to 4 bytes.
        """
        if self.nl_flags & Netlink.NLM_F_CAPPED or len(self.raw) < 8:
            return 20
        orig_len = struct.unpack_from("I", self.raw, 4)[0]
        return max(20, (4 + orig_len + 3) & ~3)

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY nest into a dict of policy facts."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                attr_type = attr.as_scalar('u32')
                try:
                    policy['type'] = Netlink.AttrType(attr_type).name
                except ValueError:
                    # A type newer than this table: keep the number rather
                    # than failing to parse the whole message.
                    policy['type'] = attr_type
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def cmd(self):
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg
277
278
class NlMsgs:
    """All netlink messages contained in one receive buffer.

    Args:
        data: bytes returned by a single recv().
        attr_space: optional attribute-set spec handed to every NlMsg.

    Raises:
        Exception: a message header reports a length shorter than the
            16-byte nlmsghdr (corrupt buffer).
    """

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            # nlmsg_len includes the header; a smaller value is corrupt and
            # would keep the offset from advancing.
            if msg.nl_len < 16:
                raise Exception(f"Invalid netlink message length {msg.nl_len} at offset {offset}")
            self.msgs.append(msg)
            # Messages start on NLMSG_ALIGN (4-byte) boundaries even when a
            # message's nlmsg_len is not itself a multiple of 4.
            offset += (msg.nl_len + 3) & ~3

    def __iter__(self):
        yield from self.msgs
291
292
# Module-level cache of generic netlink families, filled by
# _genl_load_families() from the kernel's nlctrl family dump.  Maps family
# name -> dict with 'id', 'name' and, when reported, 'maxattr' and 'mcast'
# (multicast group name -> group id).  Stays None until the first
# GenlProtocol needs it.
genl_family_name_to_id = None
294
295
296def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
297    # we prepend length in _genl_msg_finalize()
298    if seq is None:
299        seq = random.randint(1, 1024)
300    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
301    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
302    return nlmsg + genlmsg
303
304
305def _genl_msg_finalize(msg):
306    return struct.pack("I", len(msg) + 4) + msg
307
308
def _genl_load_families():
    """Populate genl_family_name_to_id from a kernel nlctrl family dump.

    Sends a CTRL_CMD_GETFAMILY dump request on a new NETLINK_GENERIC socket.
    Once the request is sent, the global is reset to an empty dict.  Every
    reply carrying both a family name and id is then stored under the name
    as a dict with 'id', 'name' and, if present, 'maxattr' and 'mcast'
    (group name -> group id).  A netlink error is printed and ends the dump
    early; the DONE message ends it normally.
    """
    global genl_family_name_to_id

    def parse_mcast_groups(groups_attr):
        # Each nested entry holds one group's name and id; keep full pairs.
        groups = {}
        for entry in NlAttrs(groups_attr.raw):
            grp_name = None
            grp_id = None
            for field in NlAttrs(entry.raw):
                if field.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                    grp_name = field.as_strz()
                elif field.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                    grp_id = field.as_scalar('u32')
            if grp_name and grp_id is not None:
                groups[grp_name] = grp_id
        return groups

    def parse_family(genl_msg):
        family = {}
        for attr in NlAttrs(genl_msg.raw):
            if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                family['id'] = attr.as_scalar('u16')
            elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                family['name'] = attr.as_strz()
            elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                family['maxattr'] = attr.as_scalar('u32')
            elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                family['mcast'] = parse_mcast_groups(attr)
        return family

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL,
                      Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                      Netlink.CTRL_CMD_GETFAMILY, 1))
        sock.send(request, 0)

        genl_family_name_to_id = {}

        while True:
            reply = sock.recv(128 * 1024)
            for nl_msg in NlMsgs(reply):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                family = parse_family(GenlMsg(nl_msg))
                if 'name' in family and 'id' in family:
                    genl_family_name_to_id[family['name']] = family
356
357
class GenlMsg:
    """Generic netlink view of a message: genlmsghdr plus payload.

    Attributes:
        nl: the wrapped message; its .raw must begin with the 4-byte
            genlmsghdr (cmd, version, reserved).
        genl_cmd, genl_version: genlmsghdr fields.
        raw: payload following the genlmsghdr.
        raw_attrs: parsed attributes, attached later by
            NetlinkProtocol.decode() (absent until then).
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        # repr() of the inner message does not end with a newline.
        msg += f"\n\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs only exists once the message went through decode();
        # before that there are no parsed attributes to list.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg
373
374
class NetlinkProtocol:
    """Message framing for a classic (non-genetlink) netlink protocol.

    Attributes:
        family_name: family name from the spec.
        proto_num: netlink protocol number used to open the socket.
    """

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack an nlmsghdr without nlmsg_len: type, flags, seq, port id 0.

        A random sequence number in [1, 1024] is used when *seq* is None.
        """
        sequence = random.randint(1, 1024) if seq is None else seq
        return struct.pack("HHII", nl_type, nl_flags, sequence, 0)

    def message(self, flags, command, version, seq=None):
        """Header for *command*; the nlmsg_type is the command itself and
        *version* is not used by this protocol."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap *nl_msg* and attach its parsed attributes as raw_attrs.

        When *op* is None it is looked up from the message's command in
        ynl.rsp_by_value.  The op's fixed header, sized with
        ynl._struct_size(), is skipped before attributes are parsed.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        header_len = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, header_len)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the id of multicast group *mcast_name* from the spec's
        *mcast_groups* mapping."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        # struct nlmsghdr
        return 16
407
408
class GenlProtocol(NetlinkProtocol):
    """Message framing for a generic netlink family.

    The family id and multicast group ids come from the kernel: the nlctrl
    family table is loaded on first use and cached in the module-level
    genl_family_name_to_id dict.

    Raises:
        KeyError: *family_name* is not among the families the kernel reported.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        # Only read here; _genl_load_families() rebinds the module global.
        if genl_family_name_to_id is None:
            _genl_load_families()

        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']

    def message(self, flags, command, version, seq=None):
        """nlmsghdr (without length) addressed to the family id, followed
        by the genlmsghdr carrying *command* and *version*."""
        nl_header = self._message(self.family_id, flags, seq)
        genl_header = struct.pack("BBH", command, version, 0)
        return nl_header + genl_header

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the kernel-reported id of multicast group *mcast_name*.

        *mcast_groups* (from the spec) is not consulted for genetlink.
        """
        kernel_groups = self.genl_family['mcast']
        if mcast_name in kernel_groups:
            return kernel_groups[mcast_name]
        raise Exception(f'Multicast group "{mcast_name}" not present in the family')

    def msghdr_size(self):
        # nlmsghdr plus the 4-byte genlmsghdr
        return super().msghdr_size() + 4
435
436
class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes, innermost first.

    Used to resolve attribute values (e.g. sub-message selectors) from the
    current nesting level or any enclosing one.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer=None):
        self.scopes = [self.SpecValuesPair(attr_space, attrs)]
        if outer:
            self.scopes.extend(outer.scopes)

    def lookup(self, name):
        """Return the value of *name* from the innermost scope whose spec
        defines it.

        Raises:
            Exception: that scope has no value for *name*, or no scope
                defines *name* at all.
        """
        owner = next((scope for scope in self.scopes if name in scope.spec), None)
        if owner is None:
            raise Exception(f"Attribute '{name}' not defined in any attribute-set")
        if name not in owner.values:
            spec_name = owner.spec.yaml['name']
            raise Exception(
                f"No value for '{name}' in attribute space '{spec_name}'")
        return owner.values[name]
454
455
456#
457# YNL implementation details.
458#
459
460
461class YnlFamily(SpecFamily):
462    def __init__(self, def_path, schema=None, process_unknown=False,
463                 recv_size=0):
464        super().__init__(def_path, schema)
465
466        self.include_raw = False
467        self.process_unknown = process_unknown
468
469        try:
470            if self.proto == "netlink-raw":
471                self.nlproto = NetlinkProtocol(self.yaml['name'],
472                                               self.yaml['protonum'])
473            else:
474                self.nlproto = GenlProtocol(self.yaml['name'])
475        except KeyError:
476            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
477
478        self._recv_dbg = False
479        # Note that netlink will use conservative (min) message size for
480        # the first dump recv() on the socket, our setting will only matter
481        # from the second recv() on.
482        self._recv_size = recv_size if recv_size else 131072
483        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
484        # for a message, so smaller receive sizes will lead to truncation.
485        # Note that the min size for other families may be larger than 4k!
486        if self._recv_size < 4000:
487            raise ConfigError()
488
489        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
490        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
491        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
492        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
493
494        self.async_msg_ids = set()
495        self.async_msg_queue = queue.Queue()
496
497        for msg in self.msgs.values():
498            if msg.is_async:
499                self.async_msg_ids.add(msg.rsp_value)
500
501        for op_name, op in self.ops.items():
502            bound_f = functools.partial(self._op, op_name)
503            setattr(self, op.ident_name, bound_f)
504
505
506    def ntf_subscribe(self, mcast_name):
507        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
508        self.sock.bind((0, 0))
509        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
510                             mcast_id)
511
512    def set_recv_dbg(self, enabled):
513        self._recv_dbg = enabled
514
515    def _recv_dbg_print(self, reply, nl_msgs):
516        if not self._recv_dbg:
517            return
518        print("Recv: read", len(reply), "bytes,",
519              len(nl_msgs.msgs), "messages", file=sys.stderr)
520        for nl_msg in nl_msgs:
521            print("  ", nl_msg, file=sys.stderr)
522
523    def _encode_enum(self, attr_spec, value):
524        enum = self.consts[attr_spec['enum']]
525        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
526            scalar = 0
527            if isinstance(value, str):
528                value = [value]
529            for single_value in value:
530                scalar += enum.entries[single_value].user_value(as_flags = True)
531            return scalar
532        else:
533            return enum.entries[value].user_value()
534
535    def _get_scalar(self, attr_spec, value):
536        try:
537            return int(value)
538        except (ValueError, TypeError) as e:
539            if 'enum' in attr_spec:
540                return self._encode_enum(attr_spec, value)
541            if attr_spec.display_hint:
542                return self._from_string(value, attr_spec)
543            raise e
544
545    def _add_attr(self, space, name, value, search_attrs):
546        try:
547            attr = self.attr_sets[space][name]
548        except KeyError:
549            raise Exception(f"Space '{space}' has no attribute '{name}'")
550        nl_type = attr.value
551
552        if attr.is_multi and isinstance(value, list):
553            attr_payload = b''
554            for subvalue in value:
555                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
556            return attr_payload
557
558        if attr["type"] == 'nest':
559            nl_type |= Netlink.NLA_F_NESTED
560            attr_payload = b''
561            sub_space = attr['nested-attributes']
562            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
563            for subname, subvalue in value.items():
564                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
565        elif attr["type"] == 'flag':
566            if not value:
567                # If value is absent or false then skip attribute creation.
568                return b''
569            attr_payload = b''
570        elif attr["type"] == 'string':
571            attr_payload = str(value).encode('ascii') + b'\x00'
572        elif attr["type"] == 'binary':
573            if isinstance(value, bytes):
574                attr_payload = value
575            elif isinstance(value, str):
576                if attr.display_hint:
577                    attr_payload = self._from_string(value, attr)
578                else:
579                    attr_payload = bytes.fromhex(value)
580            elif isinstance(value, dict) and attr.struct_name:
581                attr_payload = self._encode_struct(attr.struct_name, value)
582            else:
583                raise Exception(f'Unknown type for binary attribute, value: {value}')
584        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
585            scalar = self._get_scalar(attr, value)
586            if attr.is_auto_scalar:
587                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
588            else:
589                attr_type = attr["type"]
590            format = NlAttr.get_format(attr_type, attr.byte_order)
591            attr_payload = format.pack(scalar)
592        elif attr['type'] in "bitfield32":
593            scalar_value = self._get_scalar(attr, value["value"])
594            scalar_selector = self._get_scalar(attr, value["selector"])
595            attr_payload = struct.pack("II", scalar_value, scalar_selector)
596        elif attr['type'] == 'sub-message':
597            msg_format, _ = self._resolve_selector(attr, search_attrs)
598            attr_payload = b''
599            if msg_format.fixed_header:
600                attr_payload += self._encode_struct(msg_format.fixed_header, value)
601            if msg_format.attr_set:
602                if msg_format.attr_set in self.attr_sets:
603                    nl_type |= Netlink.NLA_F_NESTED
604                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
605                    for subname, subvalue in value.items():
606                        attr_payload += self._add_attr(msg_format.attr_set,
607                                                       subname, subvalue, sub_attrs)
608                else:
609                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
610        else:
611            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
612
613        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
614        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
615
616    def _decode_enum(self, raw, attr_spec):
617        enum = self.consts[attr_spec['enum']]
618        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
619            i = 0
620            value = set()
621            while raw:
622                if raw & 1:
623                    value.add(enum.entries_by_val[i].name)
624                raw >>= 1
625                i += 1
626        else:
627            value = enum.entries_by_val[raw].name
628        return value
629
630    def _decode_binary(self, attr, attr_spec):
631        if attr_spec.struct_name:
632            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
633        elif attr_spec.sub_type:
634            decoded = attr.as_c_array(attr_spec.sub_type)
635            if 'enum' in attr_spec:
636                decoded = [ self._decode_enum(x, attr_spec) for x in decoded ]
637            elif attr_spec.display_hint:
638                decoded = [ self._formatted_string(x, attr_spec.display_hint)
639                            for x in decoded ]
640        else:
641            decoded = attr.as_bin()
642            if attr_spec.display_hint:
643                decoded = self._formatted_string(decoded, attr_spec.display_hint)
644        return decoded
645
646    def _decode_array_attr(self, attr, attr_spec):
647        decoded = []
648        offset = 0
649        while offset < len(attr.raw):
650            item = NlAttr(attr.raw, offset)
651            offset += item.full_len
652
653            if attr_spec["sub-type"] == 'nest':
654                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
655                decoded.append({ item.type: subattrs })
656            elif attr_spec["sub-type"] == 'binary':
657                subattr = item.as_bin()
658                if attr_spec.display_hint:
659                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
660                decoded.append(subattr)
661            elif attr_spec["sub-type"] in NlAttr.type_formats:
662                subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
663                if 'enum' in attr_spec:
664                    subattr = self._decode_enum(subattr, attr_spec)
665                elif attr_spec.display_hint:
666                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
667                decoded.append(subattr)
668            else:
669                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
670        return decoded
671
672    def _decode_nest_type_value(self, attr, attr_spec):
673        decoded = {}
674        value = attr
675        for name in attr_spec['type-value']:
676            value = NlAttr(value.raw, 0)
677            decoded[name] = value.type
678        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
679        decoded.update(subattrs)
680        return decoded
681
682    def _decode_unknown(self, attr):
683        if attr.is_nest:
684            return self._decode(NlAttrs(attr.raw), None)
685        else:
686            return attr.as_bin()
687
688    def _rsp_add(self, rsp, name, is_multi, decoded):
689        if is_multi == None:
690            if name in rsp and type(rsp[name]) is not list:
691                rsp[name] = [rsp[name]]
692                is_multi = True
693            else:
694                is_multi = False
695
696        if not is_multi:
697            rsp[name] = decoded
698        elif name in rsp:
699            rsp[name].append(decoded)
700        else:
701            rsp[name] = [decoded]
702
703    def _resolve_selector(self, attr_spec, search_attrs):
704        sub_msg = attr_spec.sub_message
705        if sub_msg not in self.sub_msgs:
706            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
707        sub_msg_spec = self.sub_msgs[sub_msg]
708
709        selector = attr_spec.selector
710        value = search_attrs.lookup(selector)
711        if value not in sub_msg_spec.formats:
712            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
713
714        spec = sub_msg_spec.formats[value]
715        return spec, value
716
717    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
718        msg_format, _ = self._resolve_selector(attr_spec, search_attrs)
719        decoded = {}
720        offset = 0
721        if msg_format.fixed_header:
722            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header));
723            offset = self._struct_size(msg_format.fixed_header)
724        if msg_format.attr_set:
725            if msg_format.attr_set in self.attr_sets:
726                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
727                decoded.update(subdict)
728            else:
729                raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")
730        return decoded
731
732    def _decode(self, attrs, space, outer_attrs = None):
733        rsp = dict()
734        if space:
735            attr_space = self.attr_sets[space]
736            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
737
738        for attr in attrs:
739            try:
740                attr_spec = attr_space.attrs_by_val[attr.type]
741            except (KeyError, UnboundLocalError):
742                if not self.process_unknown:
743                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
744                attr_name = f"UnknownAttr({attr.type})"
745                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
746                continue
747
748            try:
749                if attr_spec["type"] == 'nest':
750                    subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
751                    decoded = subdict
752                elif attr_spec["type"] == 'string':
753                    decoded = attr.as_strz()
754                elif attr_spec["type"] == 'binary':
755                    decoded = self._decode_binary(attr, attr_spec)
756                elif attr_spec["type"] == 'flag':
757                    decoded = True
758                elif attr_spec.is_auto_scalar:
759                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
760                elif attr_spec["type"] in NlAttr.type_formats:
761                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
762                    if 'enum' in attr_spec:
763                        decoded = self._decode_enum(decoded, attr_spec)
764                    elif attr_spec.display_hint:
765                        decoded = self._formatted_string(decoded, attr_spec.display_hint)
766                elif attr_spec["type"] == 'indexed-array':
767                    decoded = self._decode_array_attr(attr, attr_spec)
768                elif attr_spec["type"] == 'bitfield32':
769                    value, selector = struct.unpack("II", attr.raw)
770                    if 'enum' in attr_spec:
771                        value = self._decode_enum(value, attr_spec)
772                        selector = self._decode_enum(selector, attr_spec)
773                    decoded = {"value": value, "selector": selector}
774                elif attr_spec["type"] == 'sub-message':
775                    decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
776                elif attr_spec["type"] == 'nest-type-value':
777                    decoded = self._decode_nest_type_value(attr, attr_spec)
778                else:
779                    if not self.process_unknown:
780                        raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
781                    decoded = self._decode_unknown(attr)
782
783                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
784            except:
785                print(f"Error decoding '{attr_spec.name}' from '{space}'")
786                raise
787
788        return rsp
789
    def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs):
        """
        Translate an extack bad-attribute byte offset into an attribute path.

        Walks `attrs` (attributes belonging to `attr_set`, the first of which
        starts at byte `offset`), recursing into nests and sub-messages until
        it finds the attribute that starts exactly at byte `target`.

        Returns a dotted path such as '.outer.inner' (a sub-message level is
        rendered as 'name(selector-value)'), or None when no attribute starts
        at `target`.

        Raises Exception when a visited attribute type is not part of
        `attr_set`, when a sub-message selector cannot be resolved, or when
        `target` falls inside an attribute that is neither a nest nor a
        sub-message.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            # Already past the target: it does not point at the start of an
            # attribute at this level.
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            # Target lies beyond this attribute entirely; skip over it.
            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue

            # Target lies inside this attribute; descend into its payload.
            pathname = attr_spec.name
            if attr_spec['type'] == 'nest':
                sub_attrs = self.attr_sets[attr_spec['nested-attributes']]
                # Narrow the selector search scope to the nested values.
                search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name']))
            elif attr_spec['type'] == 'sub-message':
                msg_format, value = self._resolve_selector(attr_spec, search_attrs)
                if msg_format is None:
                    raise Exception(f"Can't resolve sub-message of {attr_spec['name']} for extack")
                sub_attrs = self.attr_sets[msg_format.attr_set]
                pathname += f"({value})"
            else:
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # NOTE(review): presumably skips the 4-byte attribute header so
            # that offset points at the first nested attribute - confirm.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs,
                                               offset, target, search_attrs)
            if subpath is None:
                return None
            return '.' + pathname + subpath

        return None
825
826    def _decode_extack(self, request, op, extack, vals):
827        if 'bad-attr-offs' not in extack:
828            return
829
830        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
831        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
832        search_attrs = SpaceAttrs(op.attr_set, vals)
833        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
834                                        extack['bad-attr-offs'], search_attrs)
835        if path:
836            del extack['bad-attr-offs']
837            extack['bad-attr'] = path
838
839    def _struct_size(self, name):
840        if name:
841            members = self.consts[name].members
842            size = 0
843            for m in members:
844                if m.type in ['pad', 'binary']:
845                    if m.struct:
846                        size += self._struct_size(m.struct)
847                    else:
848                        size += m.len
849                else:
850                    format = NlAttr.get_format(m.type, m.byte_order)
851                    size += format.size
852            return size
853        else:
854            return 0
855
856    def _decode_struct(self, data, name):
857        members = self.consts[name].members
858        attrs = dict()
859        offset = 0
860        for m in members:
861            value = None
862            if m.type == 'pad':
863                offset += m.len
864            elif m.type == 'binary':
865                if m.struct:
866                    len = self._struct_size(m.struct)
867                    value = self._decode_struct(data[offset : offset + len],
868                                                m.struct)
869                    offset += len
870                else:
871                    value = data[offset : offset + m.len]
872                    offset += m.len
873            else:
874                format = NlAttr.get_format(m.type, m.byte_order)
875                [ value ] = format.unpack_from(data, offset)
876                offset += format.size
877            if value is not None:
878                if m.enum:
879                    value = self._decode_enum(value, m)
880                elif m.display_hint:
881                    value = self._formatted_string(value, m.display_hint)
882                attrs[m.name] = value
883        return attrs
884
885    def _encode_struct(self, name, vals):
886        members = self.consts[name].members
887        attr_payload = b''
888        for m in members:
889            value = vals.pop(m.name) if m.name in vals else None
890            if m.type == 'pad':
891                attr_payload += bytearray(m.len)
892            elif m.type == 'binary':
893                if m.struct:
894                    if value is None:
895                        value = dict()
896                    attr_payload += self._encode_struct(m.struct, value)
897                else:
898                    if value is None:
899                        attr_payload += bytearray(m.len)
900                    else:
901                        attr_payload += bytes.fromhex(value)
902            else:
903                if value is None:
904                    value = 0
905                format = NlAttr.get_format(m.type, m.byte_order)
906                attr_payload += format.pack(value)
907        return attr_payload
908
909    def _formatted_string(self, raw, display_hint):
910        if display_hint == 'mac':
911            formatted = ':'.join('%02x' % b for b in raw)
912        elif display_hint == 'hex':
913            if isinstance(raw, int):
914                formatted = hex(raw)
915            else:
916                formatted = bytes.hex(raw, ' ')
917        elif display_hint in [ 'ipv4', 'ipv6' ]:
918            formatted = format(ipaddress.ip_address(raw))
919        elif display_hint == 'uuid':
920            formatted = str(uuid.UUID(bytes=raw))
921        else:
922            formatted = raw
923        return formatted
924
925    def _from_string(self, string, attr_spec):
926        if attr_spec.display_hint in ['ipv4', 'ipv6']:
927            ip = ipaddress.ip_address(string)
928            if attr_spec['type'] == 'binary':
929                raw = ip.packed
930            else:
931                raw = int(ip)
932        else:
933            raise Exception(f"Display hint '{attr_spec.display_hint}' not implemented"
934                            f" when parsing '{attr_spec['name']}'")
935        return raw
936
937    def handle_ntf(self, decoded):
938        msg = dict()
939        if self.include_raw:
940            msg['raw'] = decoded
941        op = self.rsp_by_value[decoded.cmd()]
942        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
943        if op.fixed_header:
944            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
945
946        msg['name'] = op['name']
947        msg['msg'] = attrs
948        self.async_msg_queue.put(msg)
949
950    def check_ntf(self):
951        while True:
952            try:
953                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
954            except BlockingIOError:
955                return
956
957            nms = NlMsgs(reply)
958            self._recv_dbg_print(reply, nms)
959            for nl_msg in nms:
960                if nl_msg.error:
961                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
962                    print(nl_msg)
963                    continue
964                if nl_msg.done:
965                    print("Netlink done while checking for ntf!?")
966                    continue
967
968                decoded = self.nlproto.decode(self, nl_msg, None)
969                if decoded.cmd() not in self.async_msg_ids:
970                    print("Unexpected msg id while checking for ntf", decoded)
971                    continue
972
973                self.handle_ntf(decoded)
974
975    def poll_ntf(self, duration=None):
976        start_time = time.time()
977        selector = selectors.DefaultSelector()
978        selector.register(self.sock, selectors.EVENT_READ)
979
980        while True:
981            try:
982                yield self.async_msg_queue.get_nowait()
983            except queue.Empty:
984                if duration is not None:
985                    timeout = start_time + duration - time.time()
986                    if timeout <= 0:
987                        return
988                else:
989                    timeout = None
990                events = selector.select(timeout)
991                if events:
992                    self.check_ntf()
993
994    def operation_do_attributes(self, name):
995      """
996      For a given operation name, find and return a supported
997      set of attributes (as a dict).
998      """
999      op = self.find_operation(name)
1000      if not op:
1001        return None
1002
1003      return op['do']['request']['attributes'].copy()
1004
1005    def _encode_message(self, op, vals, flags, req_seq):
1006        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
1007        for flag in flags or []:
1008            nl_flags |= flag
1009
1010        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
1011        if op.fixed_header:
1012            msg += self._encode_struct(op.fixed_header, vals)
1013        search_attrs = SpaceAttrs(op.attr_set, vals)
1014        for name, value in vals.items():
1015            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
1016        msg = _genl_msg_finalize(msg)
1017        return msg
1018
1019    def _ops(self, ops):
1020        reqs_by_seq = {}
1021        req_seq = random.randint(1024, 65535)
1022        payload = b''
1023        for (method, vals, flags) in ops:
1024            op = self.ops[method]
1025            msg = self._encode_message(op, vals, flags, req_seq)
1026            reqs_by_seq[req_seq] = (op, vals, msg, flags)
1027            payload += msg
1028            req_seq += 1
1029
1030        self.sock.send(payload, 0)
1031
1032        done = False
1033        rsp = []
1034        op_rsp = []
1035        while not done:
1036            reply = self.sock.recv(self._recv_size)
1037            nms = NlMsgs(reply, attr_space=op.attr_set)
1038            self._recv_dbg_print(reply, nms)
1039            for nl_msg in nms:
1040                if nl_msg.nl_seq in reqs_by_seq:
1041                    (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
1042                    if nl_msg.extack:
1043                        self._decode_extack(req_msg, op, nl_msg.extack, vals)
1044                else:
1045                    op = None
1046                    req_flags = []
1047
1048                if nl_msg.error:
1049                    raise NlError(nl_msg)
1050                if nl_msg.done:
1051                    if nl_msg.extack:
1052                        print("Netlink warning:")
1053                        print(nl_msg)
1054
1055                    if Netlink.NLM_F_DUMP in req_flags:
1056                        rsp.append(op_rsp)
1057                    elif not op_rsp:
1058                        rsp.append(None)
1059                    elif len(op_rsp) == 1:
1060                        rsp.append(op_rsp[0])
1061                    else:
1062                        rsp.append(op_rsp)
1063                    op_rsp = []
1064
1065                    del reqs_by_seq[nl_msg.nl_seq]
1066                    done = len(reqs_by_seq) == 0
1067                    break
1068
1069                decoded = self.nlproto.decode(self, nl_msg, op)
1070
1071                # Check if this is a reply to our request
1072                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
1073                    if decoded.cmd() in self.async_msg_ids:
1074                        self.handle_ntf(decoded)
1075                        continue
1076                    else:
1077                        print('Unexpected message: ' + repr(decoded))
1078                        continue
1079
1080                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
1081                if op.fixed_header:
1082                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
1083                op_rsp.append(rsp_msg)
1084
1085        return rsp
1086
1087    def _op(self, method, vals, flags=None, dump=False):
1088        req_flags = flags or []
1089        if dump:
1090            req_flags.append(Netlink.NLM_F_DUMP)
1091
1092        ops = [(method, vals, req_flags)]
1093        return self._ops(ops)[0]
1094
1095    def do(self, method, vals, flags=None):
1096        return self._op(method, vals, flags)
1097
1098    def dump(self, method, vals):
1099        return self._op(method, vals, dump=True)
1100
1101    def do_multi(self, ops):
1102        return self._ops(ops)
1103