xref: /linux/tools/net/ynl/pyynl/lib/ynl.py (revision 8f7aa3d3c7323f4ca2768a9e74ebbe359c4f8f88)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4from enum import Enum
5import functools
6import os
7import random
8import socket
9import struct
10from struct import Struct
11import sys
12import ipaddress
13import uuid
14import queue
15import selectors
16import time
17
18from .nlspec import SpecFamily
19
20#
21# Generic Netlink code which should really be in some library, but I can't quickly find one.
22#
23
24
class Netlink:
    """Netlink and generic netlink numeric constants used by this module.

    The values mirror the kernel's uapi definitions; they live here because
    this module speaks netlink over raw sockets without a helper library.
    """

    # --- SOL_NETLINK socket option level and the options set on it ---
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # --- Control message types (nlmsg_type) ---
    NLMSG_ERROR = 0x2
    NLMSG_DONE = 0x3

    # --- nlmsg_flags shared by all requests ---
    NLM_F_REQUEST = 0x1
    NLM_F_ACK = 0x4

    # --- Modifiers for GET requests ---
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # --- Modifiers for NEW requests (same bits, different meaning) ---
    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    # --- Flags found on ACK / error replies ---
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    # --- nla_type flag bits; NLA_TYPE_MASK strips them from the type ---
    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # --- Generic netlink protocol number and controller family id ---
    NETLINK_GENERIC = 16
    GENL_ID_CTRL = 0x10

    # --- nlctrl command and attributes ---
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # --- Extended ACK attribute types ---
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # --- Attributes describing a policy inside NLMSGERR_ATTR_POLICY ---
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Policy attribute types; the functional Enum API numbers them from 1,
    # so 'flag' == 1 ... 'uint' == 17.
    AttrType = Enum('AttrType', ('flag', 'u8', 'u16', 'u32', 'u64',
                                 's8', 's16', 's32', 's64',
                                 'binary', 'string', 'nul-string',
                                 'nested', 'nested-array',
                                 'bitfield32', 'sint', 'uint'))
101
class NlError(Exception):
    """Exception carrying a netlink message whose error field was non-zero."""

    def __init__(self, nl_msg):
        # Keep the whole message so callers can look at extack/header fields.
        self.nl_msg = nl_msg
        # The message stores a negative errno; expose the positive value.
        self.error = -nl_msg.error

    def __str__(self):
        # Work on a copy so the message's own extack dict stays untouched.
        extack = self.nl_msg.extack.copy() if self.nl_msg.extack else {}
        parts = ["Netlink error: "]
        if 'msg' in extack:
            parts.append(extack.pop('msg') + ': ')
        parts.append(os.strerror(self.error))
        if extack:
            parts.append(' ' + str(extack))
        return ''.join(parts)
118
119
class ConfigError(Exception):
    """Raised when YnlFamily is given an unusable local configuration."""
122
123
class NlAttr:
    """A single netlink attribute (struct nlattr + payload) read from a buffer."""

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    # Pre-compiled struct formats per scalar type, for each byte order.
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute header at ``offset`` in ``raw``.

        Raises Exception if the header's length is smaller than the 4-byte
        header itself.
        """
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # nla_len includes the header, so < 4 is malformed. Rejecting it here
        # also keeps loops that advance by full_len from spinning forever on
        # a zero-length attribute.
        if self._len < 4:
            raise Exception(f"Malformed netlink attribute: length {self._len} "
                            f"is shorter than the 4 byte header")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        # Attributes are padded to a 4-byte boundary.
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for attr_type.

        Uses native order unless byte_order is given: "big-endian" selects
        big endian and any other value selects little endian.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" \
                else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as a single fixed-width scalar."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable-width (sint/uint) scalar of 4 or 8 bytes."""
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar len payload must be 4 or 8 bytes, got {len(self.raw)}")
        # 'uint' -> 'u32'/'u64', 'sint' -> 's32'/'s64'
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode an ASCII, NUL-terminated string, dropping the last byte."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed native-order array of ``type``."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
176
177
class NlAttrs:
    """Sequence of NlAttr objects laid out back-to-back in a buffer."""

    def __init__(self, msg, offset=0):
        parsed = []
        pos = offset
        end = len(msg)
        while pos < end:
            item = NlAttr(msg, pos)
            parsed.append(item)
            # Step over the header, the payload and the alignment padding.
            pos += item.full_len
        self.attrs = parsed

    def __iter__(self):
        for item in self.attrs:
            yield item

    def __repr__(self):
        # One attribute per line.
        return '\n'.join(repr(item) for item in self.attrs)
197
198
class NlMsg:
    """One netlink message parsed out of a receive buffer.

    Decodes the 16-byte nlmsghdr and keeps the payload in ``raw``. For
    NLMSG_ERROR / NLMSG_DONE it also extracts the error code and any
    extended-ACK (extack) attributes into ``extack``.
    """

    def __init__(self, msg, offset, attr_space=None):
        """Parse the message that starts at ``offset`` within ``msg``.

        attr_space, if given, is the attribute set used to turn a numeric
        'miss-type' in extack into an attribute name.
        """
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = self._error_extack_offset()
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # DONE carries only a 4-byte status ahead of the extack TLVs.
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = self._decode_extack(self.raw[extack_off:])
            if attr_space:
                self.annotate_extack(attr_space)

    def _error_extack_offset(self):
        """Return where extack TLVs start within an NLMSG_ERROR payload.

        The payload begins with the 4-byte error code and a copy of the
        request's 16-byte nlmsghdr. Without NLM_F_CAPPED the kernel also
        echoes the request's payload, so the whole echoed request (its
        nlmsg_len, 4-byte aligned) must be skipped.
        """
        if self.nl_flags & Netlink.NLM_F_CAPPED or len(self.raw) < 8:
            return 20
        echoed_len = struct.unpack("I", self.raw[4:8])[0]
        return (4 + echoed_len + 3) & ~3

    def _decode_extack(self, raw):
        """Decode extack attributes from ``raw`` into a dict."""
        extack = dict()
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NLMSGERR_ATTR_MSG:
                extack['msg'] = attr.as_strz()
            elif attr.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                extack['miss-type'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                extack['miss-nest'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NLMSGERR_ATTR_OFFS:
                extack['bad-attr-offs'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NLMSGERR_ATTR_POLICY:
                extack['policy'] = self._decode_policy(attr.raw)
            else:
                # Keep TLVs without a dedicated decoder (e.g. the cookie).
                extack.setdefault('unknown', []).append(attr)
        return extack

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY nest into a dict."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_id = attr.as_scalar('u32')
                try:
                    policy['type'] = Netlink.AttrType(type_id).name
                except ValueError:
                    # A type this table does not know (newer kernel, or 0):
                    # keep the number instead of failing the error decode.
                    policy['type'] = type_id
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def annotate_extack(self, attr_space):
        """ Make extack more human friendly with attribute information """

        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
            miss_type = self.extack['miss-type']
            if miss_type in attr_space.attrs_by_val:
                spec = attr_space.attrs_by_val[miss_type]
                self.extack['miss-type'] = spec['name']
                if 'doc' in spec:
                    self.extack['miss-type-doc'] = spec['doc']

    def cmd(self):
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg
290
291
class NlMsgs:
    """All netlink messages contained in one receive buffer."""

    def __init__(self, data):
        """Split ``data`` into NlMsg objects.

        Raises Exception on a message whose length is smaller than the
        16-byte nlmsghdr.
        """
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset)
            # nlmsg_len includes the header; anything shorter is malformed
            # and would otherwise keep this loop at the same offset forever.
            if msg.nl_len < 16:
                raise Exception(f"Malformed netlink message: length {msg.nl_len} "
                                f"at offset {offset}")
            # Consecutive messages start on 4-byte boundaries (NLMSG_ALIGN).
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
304
305
# Cache of the generic netlink families reported by the kernel, keyed by
# family name. It stays None until _genl_load_families() fills it (GenlProtocol
# triggers that on first use). Each value is a dict with 'id' and 'name', plus
# 'maxattr' and 'mcast' (group name -> group id) when the kernel reports them.
genl_family_name_to_id = None
307
308
309def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
310    # we prepend length in _genl_msg_finalize()
311    if seq is None:
312        seq = random.randint(1, 1024)
313    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
314    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
315    return nlmsg + genlmsg
316
317
318def _genl_msg_finalize(msg):
319    return struct.pack("I", len(msg) + 4) + msg
320
321
def _genl_parse_mcast_groups(raw):
    """Return {group name: group id} from a CTRL_ATTR_MCAST_GROUPS payload."""
    groups = dict()
    for entry in NlAttrs(raw):
        mcast_name = None
        mcast_id = None
        for entry_attr in NlAttrs(entry.raw):
            if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                mcast_name = entry_attr.as_strz()
            elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                mcast_id = entry_attr.as_scalar('u32')
        # Ignore entries that lack either a name or an id.
        if mcast_name and mcast_id is not None:
            groups[mcast_name] = mcast_id
    return groups


def _genl_parse_family(gm):
    """Collect id, name, maxattr and mcast groups from one GETFAMILY reply."""
    fam = dict()
    for attr in NlAttrs(gm.raw):
        if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
            fam['id'] = attr.as_scalar('u16')
        elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
            fam['name'] = attr.as_strz()
        elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
            fam['maxattr'] = attr.as_scalar('u32')
        elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
            fam['mcast'] = _genl_parse_mcast_groups(attr.raw)
    return fam


def _genl_load_families():
    """Dump all generic netlink families into genl_family_name_to_id.

    Sends a CTRL_CMD_GETFAMILY dump to nlctrl and records every family
    reported with both a name and an id. On a netlink error the error is
    printed to stderr and loading stops, keeping the families collected so
    far.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            for nl_msg in NlMsgs(reply):
                if nl_msg.error:
                    # Report on stderr so the diagnostic does not mix with
                    # regular program output on stdout.
                    print("Netlink error:", nl_msg.error, file=sys.stderr)
                    return
                if nl_msg.done:
                    return

                fam = _genl_parse_family(GenlMsg(nl_msg))
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam
369
370
class GenlMsg:
    """Generic netlink view of an NlMsg: the genlmsghdr plus what follows it."""

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        # Everything after the 4-byte genlmsghdr (fixed header and attributes).
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is only attached by NetlinkProtocol.decode(); messages
        # built elsewhere (e.g. during family discovery) do not have it.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg
386
387
class NetlinkProtocol:
    """Message framing for classic (non-generic) netlink families."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack an nlmsghdr without nlmsg_len; portid is always 0."""
        if seq is None:
            seq = random.randint(1, 1024)
        return struct.pack("HHII", nl_type, nl_flags, seq, 0)

    def message(self, flags, command, version, seq=None):
        """Build a request header; the command is used as the nlmsg type and
        version is ignored for raw netlink."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        # Raw netlink has no extra protocol header to peel off.
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap nl_msg and attach its attributes as ``raw_attrs``.

        When op is None it is looked up from the message's command. The
        attributes start after op's fixed header.
        """
        decoded = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[decoded.cmd()]
        hdr_size = ynl._struct_size(op.fixed_header)
        decoded.raw_attrs = NlAttrs(decoded.raw, hdr_size)
        return decoded

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the group id the spec assigns to mcast_name."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        """Size of the protocol header: just the 16-byte nlmsghdr."""
        return 16
420
421
class GenlProtocol(NetlinkProtocol):
    """Message framing for generic netlink families (adds the genlmsghdr)."""

    def __init__(self, family_name):
        """Resolve family_name through the kernel's family list.

        Raises KeyError if the kernel did not report the family.
        """
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']

    def message(self, flags, command, version, seq=None):
        """nlmsghdr (type = family id) followed by the 4-byte genlmsghdr."""
        header = self._message(self.family_id, flags, seq)
        return header + struct.pack("BBH", command, version, 0)

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the group id the kernel reported for mcast_name.

        mcast_groups (from the spec) is ignored; the kernel's ids are used.
        """
        groups = self.genl_family['mcast']
        if mcast_name not in groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return groups[mcast_name]

    def msghdr_size(self):
        # nlmsghdr plus the 4-byte genlmsghdr
        return super().msghdr_size() + 4
448
449
class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes used to look up selectors.

    The innermost scope is searched first, then each enclosing one.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer=None):
        self.scopes = [self.SpecValuesPair(attr_space, attrs)]
        if outer:
            self.scopes.extend(outer.scopes)

    def lookup(self, name):
        """Return the value of ``name`` from the first scope that defines it.

        Raises Exception if that scope has no value for it, or if no scope's
        attribute set defines it.
        """
        for spec, values in self.scopes:
            if name not in spec:
                continue
            if name in values:
                return values[name]
            spec_name = spec.yaml['name']
            raise Exception(
                f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")
467
468
469#
470# YNL implementation details.
471#
472
473
474class YnlFamily(SpecFamily):
475    def __init__(self, def_path, schema=None, process_unknown=False,
476                 recv_size=0):
477        super().__init__(def_path, schema)
478
479        self.include_raw = False
480        self.process_unknown = process_unknown
481
482        try:
483            if self.proto == "netlink-raw":
484                self.nlproto = NetlinkProtocol(self.yaml['name'],
485                                               self.yaml['protonum'])
486            else:
487                self.nlproto = GenlProtocol(self.yaml['name'])
488        except KeyError:
489            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
490
491        self._recv_dbg = False
492        # Note that netlink will use conservative (min) message size for
493        # the first dump recv() on the socket, our setting will only matter
494        # from the second recv() on.
495        self._recv_size = recv_size if recv_size else 131072
496        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
497        # for a message, so smaller receive sizes will lead to truncation.
498        # Note that the min size for other families may be larger than 4k!
499        if self._recv_size < 4000:
500            raise ConfigError()
501
502        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
503        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
504        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
505        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
506
507        self.async_msg_ids = set()
508        self.async_msg_queue = queue.Queue()
509
510        for msg in self.msgs.values():
511            if msg.is_async:
512                self.async_msg_ids.add(msg.rsp_value)
513
514        for op_name, op in self.ops.items():
515            bound_f = functools.partial(self._op, op_name)
516            setattr(self, op.ident_name, bound_f)
517
518
519    def ntf_subscribe(self, mcast_name):
520        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
521        self.sock.bind((0, 0))
522        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
523                             mcast_id)
524
525    def set_recv_dbg(self, enabled):
526        self._recv_dbg = enabled
527
528    def _recv_dbg_print(self, reply, nl_msgs):
529        if not self._recv_dbg:
530            return
531        print("Recv: read", len(reply), "bytes,",
532              len(nl_msgs.msgs), "messages", file=sys.stderr)
533        for nl_msg in nl_msgs:
534            print("  ", nl_msg, file=sys.stderr)
535
536    def _encode_enum(self, attr_spec, value):
537        enum = self.consts[attr_spec['enum']]
538        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
539            scalar = 0
540            if isinstance(value, str):
541                value = [value]
542            for single_value in value:
543                scalar += enum.entries[single_value].user_value(as_flags = True)
544            return scalar
545        else:
546            return enum.entries[value].user_value()
547
548    def _get_scalar(self, attr_spec, value):
549        try:
550            return int(value)
551        except (ValueError, TypeError) as e:
552            if 'enum' in attr_spec:
553                return self._encode_enum(attr_spec, value)
554            if attr_spec.display_hint:
555                return self._from_string(value, attr_spec)
556            raise e
557
558    def _add_attr(self, space, name, value, search_attrs):
559        try:
560            attr = self.attr_sets[space][name]
561        except KeyError:
562            raise Exception(f"Space '{space}' has no attribute '{name}'")
563        nl_type = attr.value
564
565        if attr.is_multi and isinstance(value, list):
566            attr_payload = b''
567            for subvalue in value:
568                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
569            return attr_payload
570
571        if attr["type"] == 'nest':
572            nl_type |= Netlink.NLA_F_NESTED
573            sub_space = attr['nested-attributes']
574            attr_payload = self._add_nest_attrs(value, sub_space, search_attrs)
575        elif attr['type'] == 'indexed-array' and attr['sub-type'] == 'nest':
576            nl_type |= Netlink.NLA_F_NESTED
577            sub_space = attr['nested-attributes']
578            attr_payload = self._encode_indexed_array(value, sub_space,
579                                                      search_attrs)
580        elif attr["type"] == 'flag':
581            if not value:
582                # If value is absent or false then skip attribute creation.
583                return b''
584            attr_payload = b''
585        elif attr["type"] == 'string':
586            attr_payload = str(value).encode('ascii') + b'\x00'
587        elif attr["type"] == 'binary':
588            if value is None:
589                attr_payload = b''
590            elif isinstance(value, bytes):
591                attr_payload = value
592            elif isinstance(value, str):
593                if attr.display_hint:
594                    attr_payload = self._from_string(value, attr)
595                else:
596                    attr_payload = bytes.fromhex(value)
597            elif isinstance(value, dict) and attr.struct_name:
598                attr_payload = self._encode_struct(attr.struct_name, value)
599            elif isinstance(value, list) and attr.sub_type in NlAttr.type_formats:
600                format = NlAttr.get_format(attr.sub_type)
601                attr_payload = b''.join([format.pack(x) for x in value])
602            else:
603                raise Exception(f'Unknown type for binary attribute, value: {value}')
604        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
605            scalar = self._get_scalar(attr, value)
606            if attr.is_auto_scalar:
607                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
608            else:
609                attr_type = attr["type"]
610            format = NlAttr.get_format(attr_type, attr.byte_order)
611            attr_payload = format.pack(scalar)
612        elif attr['type'] in "bitfield32":
613            scalar_value = self._get_scalar(attr, value["value"])
614            scalar_selector = self._get_scalar(attr, value["selector"])
615            attr_payload = struct.pack("II", scalar_value, scalar_selector)
616        elif attr['type'] == 'sub-message':
617            msg_format, _ = self._resolve_selector(attr, search_attrs)
618            attr_payload = b''
619            if msg_format.fixed_header:
620                attr_payload += self._encode_struct(msg_format.fixed_header, value)
621            if msg_format.attr_set:
622                if msg_format.attr_set in self.attr_sets:
623                    nl_type |= Netlink.NLA_F_NESTED
624                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
625                    for subname, subvalue in value.items():
626                        attr_payload += self._add_attr(msg_format.attr_set,
627                                                       subname, subvalue, sub_attrs)
628                else:
629                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
630        else:
631            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
632
633        return self._add_attr_raw(nl_type, attr_payload)
634
635    def _add_attr_raw(self, nl_type, attr_payload):
636        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
637        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
638
639    def _add_nest_attrs(self, value, sub_space, search_attrs):
640        sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
641        attr_payload = b''
642        for subname, subvalue in value.items():
643            attr_payload += self._add_attr(sub_space, subname, subvalue,
644                                           sub_attrs)
645        return attr_payload
646
647    def _encode_indexed_array(self, vals, sub_space, search_attrs):
648        attr_payload = b''
649        for i, val in enumerate(vals):
650            idx = i | Netlink.NLA_F_NESTED
651            val_payload = self._add_nest_attrs(val, sub_space, search_attrs)
652            attr_payload += self._add_attr_raw(idx, val_payload)
653        return attr_payload
654
655    def _get_enum_or_unknown(self, enum, raw):
656        try:
657            name = enum.entries_by_val[raw].name
658        except KeyError as error:
659            if self.process_unknown:
660                name = f"Unknown({raw})"
661            else:
662                raise error
663        return name
664
665    def _decode_enum(self, raw, attr_spec):
666        enum = self.consts[attr_spec['enum']]
667        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
668            i = 0
669            value = set()
670            while raw:
671                if raw & 1:
672                    value.add(self._get_enum_or_unknown(enum, i))
673                raw >>= 1
674                i += 1
675        else:
676            value = self._get_enum_or_unknown(enum, raw)
677        return value
678
679    def _decode_binary(self, attr, attr_spec):
680        if attr_spec.struct_name:
681            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
682        elif attr_spec.sub_type:
683            decoded = attr.as_c_array(attr_spec.sub_type)
684            if 'enum' in attr_spec:
685                decoded = [ self._decode_enum(x, attr_spec) for x in decoded ]
686            elif attr_spec.display_hint:
687                decoded = [ self._formatted_string(x, attr_spec.display_hint)
688                            for x in decoded ]
689        else:
690            decoded = attr.as_bin()
691            if attr_spec.display_hint:
692                decoded = self._formatted_string(decoded, attr_spec.display_hint)
693        return decoded
694
695    def _decode_array_attr(self, attr, attr_spec):
696        decoded = []
697        offset = 0
698        while offset < len(attr.raw):
699            item = NlAttr(attr.raw, offset)
700            offset += item.full_len
701
702            if attr_spec["sub-type"] == 'nest':
703                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
704                decoded.append({ item.type: subattrs })
705            elif attr_spec["sub-type"] == 'binary':
706                subattr = item.as_bin()
707                if attr_spec.display_hint:
708                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
709                decoded.append(subattr)
710            elif attr_spec["sub-type"] in NlAttr.type_formats:
711                subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
712                if 'enum' in attr_spec:
713                    subattr = self._decode_enum(subattr, attr_spec)
714                elif attr_spec.display_hint:
715                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
716                decoded.append(subattr)
717            else:
718                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
719        return decoded
720
721    def _decode_nest_type_value(self, attr, attr_spec):
722        decoded = {}
723        value = attr
724        for name in attr_spec['type-value']:
725            value = NlAttr(value.raw, 0)
726            decoded[name] = value.type
727        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
728        decoded.update(subattrs)
729        return decoded
730
731    def _decode_unknown(self, attr):
732        if attr.is_nest:
733            return self._decode(NlAttrs(attr.raw), None)
734        else:
735            return attr.as_bin()
736
737    def _rsp_add(self, rsp, name, is_multi, decoded):
738        if is_multi is None:
739            if name in rsp and type(rsp[name]) is not list:
740                rsp[name] = [rsp[name]]
741                is_multi = True
742            else:
743                is_multi = False
744
745        if not is_multi:
746            rsp[name] = decoded
747        elif name in rsp:
748            rsp[name].append(decoded)
749        else:
750            rsp[name] = [decoded]
751
752    def _resolve_selector(self, attr_spec, search_attrs):
753        sub_msg = attr_spec.sub_message
754        if sub_msg not in self.sub_msgs:
755            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
756        sub_msg_spec = self.sub_msgs[sub_msg]
757
758        selector = attr_spec.selector
759        value = search_attrs.lookup(selector)
760        if value not in sub_msg_spec.formats:
761            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
762
763        spec = sub_msg_spec.formats[value]
764        return spec, value
765
766    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
767        msg_format, _ = self._resolve_selector(attr_spec, search_attrs)
768        decoded = {}
769        offset = 0
770        if msg_format.fixed_header:
771            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
772            offset = self._struct_size(msg_format.fixed_header)
773        if msg_format.attr_set:
774            if msg_format.attr_set in self.attr_sets:
775                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
776                decoded.update(subdict)
777            else:
778                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
779        return decoded
780
781    def _decode(self, attrs, space, outer_attrs = None):
782        rsp = dict()
783        if space:
784            attr_space = self.attr_sets[space]
785            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
786
787        for attr in attrs:
788            try:
789                attr_spec = attr_space.attrs_by_val[attr.type]
790            except (KeyError, UnboundLocalError):
791                if not self.process_unknown:
792                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
793                attr_name = f"UnknownAttr({attr.type})"
794                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
795                continue
796
797            try:
798                if attr_spec["type"] == 'nest':
799                    subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
800                    decoded = subdict
801                elif attr_spec["type"] == 'string':
802                    decoded = attr.as_strz()
803                elif attr_spec["type"] == 'binary':
804                    decoded = self._decode_binary(attr, attr_spec)
805                elif attr_spec["type"] == 'flag':
806                    decoded = True
807                elif attr_spec.is_auto_scalar:
808                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
809                    if 'enum' in attr_spec:
810                        decoded = self._decode_enum(decoded, attr_spec)
811                elif attr_spec["type"] in NlAttr.type_formats:
812                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
813                    if 'enum' in attr_spec:
814                        decoded = self._decode_enum(decoded, attr_spec)
815                    elif attr_spec.display_hint:
816                        decoded = self._formatted_string(decoded, attr_spec.display_hint)
817                elif attr_spec["type"] == 'indexed-array':
818                    decoded = self._decode_array_attr(attr, attr_spec)
819                elif attr_spec["type"] == 'bitfield32':
820                    value, selector = struct.unpack("II", attr.raw)
821                    if 'enum' in attr_spec:
822                        value = self._decode_enum(value, attr_spec)
823                        selector = self._decode_enum(selector, attr_spec)
824                    decoded = {"value": value, "selector": selector}
825                elif attr_spec["type"] == 'sub-message':
826                    decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
827                elif attr_spec["type"] == 'nest-type-value':
828                    decoded = self._decode_nest_type_value(attr, attr_spec)
829                else:
830                    if not self.process_unknown:
831                        raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
832                    decoded = self._decode_unknown(attr)
833
834                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
835            except:
836                print(f"Error decoding '{attr_spec.name}' from '{space}'")
837                raise
838
839        return rsp
840
841    def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs):
842        for attr in attrs:
843            try:
844                attr_spec = attr_set.attrs_by_val[attr.type]
845            except KeyError:
846                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
847            if offset > target:
848                break
849            if offset == target:
850                return '.' + attr_spec.name
851
852            if offset + attr.full_len <= target:
853                offset += attr.full_len
854                continue
855
856            pathname = attr_spec.name
857            if attr_spec['type'] == 'nest':
858                sub_attrs = self.attr_sets[attr_spec['nested-attributes']]
859                search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name']))
860            elif attr_spec['type'] == 'sub-message':
861                msg_format, value = self._resolve_selector(attr_spec, search_attrs)
862                if msg_format is None:
863                    raise Exception(f"Can't resolve sub-message of {attr_spec['name']} for extack")
864                sub_attrs = self.attr_sets[msg_format.attr_set]
865                pathname += f"({value})"
866            else:
867                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
868            offset += 4
869            subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs,
870                                               offset, target, search_attrs)
871            if subpath is None:
872                return None
873            return '.' + pathname + subpath
874
875        return None
876
877    def _decode_extack(self, request, op, extack, vals):
878        if 'bad-attr-offs' not in extack:
879            return
880
881        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
882        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
883        search_attrs = SpaceAttrs(op.attr_set, vals)
884        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
885                                        extack['bad-attr-offs'], search_attrs)
886        if path:
887            del extack['bad-attr-offs']
888            extack['bad-attr'] = path
889
890    def _struct_size(self, name):
891        if name:
892            members = self.consts[name].members
893            size = 0
894            for m in members:
895                if m.type in ['pad', 'binary']:
896                    if m.struct:
897                        size += self._struct_size(m.struct)
898                    else:
899                        size += m.len
900                else:
901                    format = NlAttr.get_format(m.type, m.byte_order)
902                    size += format.size
903            return size
904        else:
905            return 0
906
907    def _decode_struct(self, data, name):
908        members = self.consts[name].members
909        attrs = dict()
910        offset = 0
911        for m in members:
912            value = None
913            if m.type == 'pad':
914                offset += m.len
915            elif m.type == 'binary':
916                if m.struct:
917                    len = self._struct_size(m.struct)
918                    value = self._decode_struct(data[offset : offset + len],
919                                                m.struct)
920                    offset += len
921                else:
922                    value = data[offset : offset + m.len]
923                    offset += m.len
924            else:
925                format = NlAttr.get_format(m.type, m.byte_order)
926                [ value ] = format.unpack_from(data, offset)
927                offset += format.size
928            if value is not None:
929                if m.enum:
930                    value = self._decode_enum(value, m)
931                elif m.display_hint:
932                    value = self._formatted_string(value, m.display_hint)
933                attrs[m.name] = value
934        return attrs
935
936    def _encode_struct(self, name, vals):
937        members = self.consts[name].members
938        attr_payload = b''
939        for m in members:
940            value = vals.pop(m.name) if m.name in vals else None
941            if m.type == 'pad':
942                attr_payload += bytearray(m.len)
943            elif m.type == 'binary':
944                if m.struct:
945                    if value is None:
946                        value = dict()
947                    attr_payload += self._encode_struct(m.struct, value)
948                else:
949                    if value is None:
950                        attr_payload += bytearray(m.len)
951                    else:
952                        attr_payload += bytes.fromhex(value)
953            else:
954                if value is None:
955                    value = 0
956                format = NlAttr.get_format(m.type, m.byte_order)
957                attr_payload += format.pack(value)
958        return attr_payload
959
960    def _formatted_string(self, raw, display_hint):
961        if display_hint == 'mac':
962            formatted = ':'.join('%02x' % b for b in raw)
963        elif display_hint == 'hex':
964            if isinstance(raw, int):
965                formatted = hex(raw)
966            else:
967                formatted = bytes.hex(raw, ' ')
968        elif display_hint in [ 'ipv4', 'ipv6', 'ipv4-or-v6' ]:
969            formatted = format(ipaddress.ip_address(raw))
970        elif display_hint == 'uuid':
971            formatted = str(uuid.UUID(bytes=raw))
972        else:
973            formatted = raw
974        return formatted
975
976    def _from_string(self, string, attr_spec):
977        if attr_spec.display_hint in ['ipv4', 'ipv6', 'ipv4-or-v6']:
978            ip = ipaddress.ip_address(string)
979            if attr_spec['type'] == 'binary':
980                raw = ip.packed
981            else:
982                raw = int(ip)
983        elif attr_spec.display_hint == 'hex':
984            if attr_spec['type'] == 'binary':
985                raw = bytes.fromhex(string)
986            else:
987                raw = int(string, 16)
988        elif attr_spec.display_hint == 'mac':
989            # Parse MAC address in format "00:11:22:33:44:55" or "001122334455"
990            if ':' in string:
991                mac_bytes = [int(x, 16) for x in string.split(':')]
992            else:
993                if len(string) % 2 != 0:
994                    raise Exception(f"Invalid MAC address format: {string}")
995                mac_bytes = [int(string[i:i+2], 16) for i in range(0, len(string), 2)]
996            raw = bytes(mac_bytes)
997        else:
998            raise Exception(f"Display hint '{attr_spec.display_hint}' not implemented"
999                            f" when parsing '{attr_spec['name']}'")
1000        return raw
1001
1002    def handle_ntf(self, decoded):
1003        msg = dict()
1004        if self.include_raw:
1005            msg['raw'] = decoded
1006        op = self.rsp_by_value[decoded.cmd()]
1007        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
1008        if op.fixed_header:
1009            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
1010
1011        msg['name'] = op['name']
1012        msg['msg'] = attrs
1013        self.async_msg_queue.put(msg)
1014
1015    def check_ntf(self):
1016        while True:
1017            try:
1018                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
1019            except BlockingIOError:
1020                return
1021
1022            nms = NlMsgs(reply)
1023            self._recv_dbg_print(reply, nms)
1024            for nl_msg in nms:
1025                if nl_msg.error:
1026                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
1027                    print(nl_msg)
1028                    continue
1029                if nl_msg.done:
1030                    print("Netlink done while checking for ntf!?")
1031                    continue
1032
1033                decoded = self.nlproto.decode(self, nl_msg, None)
1034                if decoded.cmd() not in self.async_msg_ids:
1035                    print("Unexpected msg id while checking for ntf", decoded)
1036                    continue
1037
1038                self.handle_ntf(decoded)
1039
1040    def poll_ntf(self, duration=None):
1041        start_time = time.time()
1042        selector = selectors.DefaultSelector()
1043        selector.register(self.sock, selectors.EVENT_READ)
1044
1045        while True:
1046            try:
1047                yield self.async_msg_queue.get_nowait()
1048            except queue.Empty:
1049                if duration is not None:
1050                    timeout = start_time + duration - time.time()
1051                    if timeout <= 0:
1052                        return
1053                else:
1054                    timeout = None
1055                events = selector.select(timeout)
1056                if events:
1057                    self.check_ntf()
1058
1059    def operation_do_attributes(self, name):
1060        """
1061        For a given operation name, find and return a supported
1062        set of attributes (as a dict).
1063        """
1064        op = self.find_operation(name)
1065        if not op:
1066            return None
1067
1068        return op['do']['request']['attributes'].copy()
1069
1070    def _encode_message(self, op, vals, flags, req_seq):
1071        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
1072        for flag in flags or []:
1073            nl_flags |= flag
1074
1075        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
1076        if op.fixed_header:
1077            msg += self._encode_struct(op.fixed_header, vals)
1078        search_attrs = SpaceAttrs(op.attr_set, vals)
1079        for name, value in vals.items():
1080            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
1081        msg = _genl_msg_finalize(msg)
1082        return msg
1083
1084    def _ops(self, ops):
1085        reqs_by_seq = {}
1086        req_seq = random.randint(1024, 65535)
1087        payload = b''
1088        for (method, vals, flags) in ops:
1089            op = self.ops[method]
1090            msg = self._encode_message(op, vals, flags, req_seq)
1091            reqs_by_seq[req_seq] = (op, vals, msg, flags)
1092            payload += msg
1093            req_seq += 1
1094
1095        self.sock.send(payload, 0)
1096
1097        done = False
1098        rsp = []
1099        op_rsp = []
1100        while not done:
1101            reply = self.sock.recv(self._recv_size)
1102            nms = NlMsgs(reply)
1103            self._recv_dbg_print(reply, nms)
1104            for nl_msg in nms:
1105                if nl_msg.nl_seq in reqs_by_seq:
1106                    (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
1107                    if nl_msg.extack:
1108                        nl_msg.annotate_extack(op.attr_set)
1109                        self._decode_extack(req_msg, op, nl_msg.extack, vals)
1110                else:
1111                    op = None
1112                    req_flags = []
1113
1114                if nl_msg.error:
1115                    raise NlError(nl_msg)
1116                if nl_msg.done:
1117                    if nl_msg.extack:
1118                        print("Netlink warning:")
1119                        print(nl_msg)
1120
1121                    if Netlink.NLM_F_DUMP in req_flags:
1122                        rsp.append(op_rsp)
1123                    elif not op_rsp:
1124                        rsp.append(None)
1125                    elif len(op_rsp) == 1:
1126                        rsp.append(op_rsp[0])
1127                    else:
1128                        rsp.append(op_rsp)
1129                    op_rsp = []
1130
1131                    del reqs_by_seq[nl_msg.nl_seq]
1132                    done = len(reqs_by_seq) == 0
1133                    break
1134
1135                decoded = self.nlproto.decode(self, nl_msg, op)
1136
1137                # Check if this is a reply to our request
1138                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
1139                    if decoded.cmd() in self.async_msg_ids:
1140                        self.handle_ntf(decoded)
1141                        continue
1142                    else:
1143                        print('Unexpected message: ' + repr(decoded))
1144                        continue
1145
1146                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
1147                if op.fixed_header:
1148                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
1149                op_rsp.append(rsp_msg)
1150
1151        return rsp
1152
1153    def _op(self, method, vals, flags=None, dump=False):
1154        req_flags = flags or []
1155        if dump:
1156            req_flags.append(Netlink.NLM_F_DUMP)
1157
1158        ops = [(method, vals, req_flags)]
1159        return self._ops(ops)[0]
1160
1161    def do(self, method, vals, flags=None):
1162        return self._op(method, vals, flags)
1163
1164    def dump(self, method, vals):
1165        return self._op(method, vals, dump=True)
1166
1167    def do_multi(self, ops):
1168        return self._ops(ops)
1169