xref: /linux/tools/net/ynl/pyynl/lib/ynl.py (revision f7c595c9d9f4cce9ec335f0d3c5d875bb547f9d5)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4from enum import Enum
5import functools
6import os
7import random
8import socket
9import struct
10from struct import Struct
11import sys
12import yaml
13import ipaddress
14import uuid
15import queue
16import selectors
17import time
18
19from .nlspec import SpecFamily
20
21#
22# Generic Netlink code which should really be in some library, but I can't quickly find one.
23#
24
25
class Netlink:
    """Numeric netlink constants used throughout this module.

    The values are meant to match the kernel uAPI headers
    (linux/netlink.h, linux/genetlink.h); check those before adding more.
    """

    # --- Socket options (setsockopt level SOL_NETLINK) ---
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # --- Control message types (nlmsg_type) ---
    NLMSG_ERROR = 0x2
    NLMSG_DONE = 0x3

    # --- Message flags (nlmsg_flags) ---
    NLM_F_REQUEST = 0x001
    NLM_F_ACK = 0x004

    # Modifiers for GET requests.
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # Modifiers for NEW requests; these reuse the GET modifier bits.
    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    # Modifiers for ACK messages.
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    # --- Flag bits carried in nla_type ---
    NLA_F_NESTED = 1 << 15
    NLA_F_NET_BYTEORDER = 1 << 14
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # --- Generic netlink ---
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl: command, family attributes, multicast-group attributes
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # --- Extended ACK attribute types ---
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # --- Policy description attributes (nested in NLMSGERR_ATTR_POLICY) ---
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Attribute types reported via NL_POLICY_TYPE_ATTR_TYPE.  The functional
    # Enum API numbers members from 1 in list order ('flag' == 1).
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                 's8', 's16', 's32', 's64',
                                 'binary', 'string', 'nul-string',
                                 'nested', 'nested-array',
                                 'bitfield32', 'sint', 'uint'])
102
class NlError(Exception):
    """Raised when the kernel answers a request with a netlink error.

    Attributes:
        nl_msg: the NlMsg that carried the error.
        error:  positive errno, i.e. ``nl_msg.error`` negated.
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        # nl_msg.error is a negative errno; keep a positive code for strerror
        self.error = -nl_msg.error

    def __str__(self):
        reason = os.strerror(self.error)
        return f"Netlink error: {reason}\n{self.nl_msg}"
110
111
class ConfigError(Exception):
    """Raised for an unusable YnlFamily configuration (e.g. a too-small recv size)."""
114
115
class NlAttr:
    """A single netlink attribute (TLV) decoded from a byte buffer.

    Attributes:
        type: attribute type with the NLA_F_NESTED / NLA_F_NET_BYTEORDER
            bits masked off.
        is_nest: non-zero when NLA_F_NESTED was set.
        payload_len: nla_len as received (4-byte header + payload, unpadded).
        full_len: nla_len rounded up to a multiple of 4, i.e. the distance
            from this attribute to the next one.
        raw: the payload bytes, without the header.
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # nla_len includes the 4-byte header, so anything smaller is
        # malformed.  Reject it here: callers walk attributes by full_len
        # and a zero length would never advance.
        if self._len < 4:
            raise Exception(f"Malformed netlink attribute at offset {offset}: "
                            f"nla_len {self._len} is shorter than the header")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for *attr_type* in the requested byte order.

        "big-endian" selects big endian, any other non-empty value little
        endian, and None/empty the host's native order.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the payload as a single fixed-width integer."""
        return self.get_format(attr_type, byte_order).unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a variable-width sint/uint payload (4 or 8 bytes).

        The first letter of *attr_type* ('s' or 'u') selects signedness.
        """
        width = len(self.raw)
        if width not in (4, 8):
            raise Exception(f"Auto-scalar payload must be 4 or 8 bytes, got {width}")
        real_type = attr_type[0] + str(width * 8)
        return self.get_format(real_type, byte_order).unpack(self.raw)[0]

    def as_strz(self):
        """Return the payload as an ASCII string without its NUL terminator.

        Only a trailing NUL that is actually present is removed, so a
        payload sent without a terminator keeps its last character.
        """
        raw = self.raw
        if raw.endswith(b'\x00'):
            raw = raw[:-1]
        return raw.decode('ascii')

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type, byte_order=None):
        """Unpack the payload as a packed array of *type* scalars.

        *byte_order* is optional; without it native order is used.
        """
        fmt = self.get_format(type, byte_order)
        return [x[0] for x in fmt.iter_unpack(self.raw)]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
168
169
class NlAttrs:
    """Sequence of NlAttr objects parsed from a packed attribute stream.

    Parsing starts at *offset* and walks attributes back to back, advancing
    by each attribute's padded length, until the end of *msg*.
    """

    def __init__(self, msg, offset=0):
        parsed = []
        pos = offset
        end = len(msg)
        while pos < end:
            nla = NlAttr(msg, pos)
            parsed.append(nla)
            pos += nla.full_len
        self.attrs = parsed

    def __iter__(self):
        for nla in self.attrs:
            yield nla

    def __repr__(self):
        # one attribute per line
        return '\n'.join(repr(nla) for nla in self.attrs)
189
190
class NlMsg:
    """One netlink message parsed out of a receive buffer.

    Attributes:
        hdr: the raw 16-byte nlmsghdr.
        nl_len, nl_type, nl_flags, nl_seq, nl_portid: nlmsghdr fields.
        raw: payload following the header.
        error: error code from an NLMSG_ERROR / NLMSG_DONE payload
            (negative errno on failure, 0 otherwise).
        done: 1 for NLMSG_ERROR and NLMSG_DONE, 0 otherwise.
        extack: dict of decoded extended-ACK TLVs, or None.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)
        # nlmsg_len counts the 16-byte header itself.  A smaller value is
        # corrupt, and since NlMsgs advances by nl_len, zero would make it
        # loop forever.
        if self.nl_len < 16:
            raise Exception(f"Malformed netlink message at offset {offset}: "
                            f"nlmsg_len {self.nl_len} is shorter than the header")

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = self._error_extack_offset()
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = {}
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    self.extack.setdefault('unknown', []).append(extack)

            if attr_space:
                self.annotate_extack(attr_space)

    def _error_extack_offset(self):
        """Return where extack TLVs start inside an NLMSG_ERROR payload.

        The payload is the 4-byte error code followed by the echoed request.
        With NLM_F_CAPPED only the request's 16-byte nlmsghdr is echoed;
        otherwise the whole request is, padded to 4 bytes (kernel ACK
        format, see Documentation/userspace-api/netlink/intro.rst).
        """
        if self.nl_flags & Netlink.NLM_F_CAPPED or len(self.raw) < 8:
            return 20
        # nlmsg_len of the echoed request sits right after the error code
        orig_len = max(struct.unpack_from("I", self.raw, 4)[0], 16)
        return 4 + ((orig_len + 3) & ~3)

    def _decode_policy(self, raw):
        """Decode a nested NLMSGERR_ATTR_POLICY into a dict of limits."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                attr_type = attr.as_scalar('u32')
                policy['type'] = Netlink.AttrType(attr_type).name
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def annotate_extack(self, attr_space):
        """ Make extack more human friendly with attribute information """

        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
            miss_type = self.extack['miss-type']
            if miss_type in attr_space.attrs_by_val:
                spec = attr_space.attrs_by_val[miss_type]
                self.extack['miss-type'] = spec['name']
                if 'doc' in spec:
                    self.extack['miss-type-doc'] = spec['doc']

    def cmd(self):
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg
282
283
class NlMsgs:
    """All netlink messages contained in one receive buffer, in order."""

    def __init__(self, data):
        self.msgs = []
        pos = 0
        while pos < len(data):
            nl_msg = NlMsg(data, pos)
            self.msgs.append(nl_msg)
            # nlmsg_len is the distance to the next message in the buffer
            pos += nl_msg.nl_len

    def __iter__(self):
        for nl_msg in self.msgs:
            yield nl_msg
296
297
# Lazily-populated cache of generic netlink families, keyed by family name.
# Each value is a dict with 'id' and 'name' and, when the kernel reported
# them, 'maxattr' and 'mcast' (multicast group name -> group id).
# Stays None until _genl_load_families() runs.
genl_family_name_to_id = None
299
300
def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build an nlmsghdr (without nlmsg_len) followed by a genlmsghdr.

    The length field is omitted here; _genl_msg_finalize() prepends it once
    attributes have been appended.  Without *seq*, a random sequence number
    in [1, 1024] is used.  The port id is always 0.
    """
    sequence = random.randint(1, 1024) if seq is None else seq
    header = struct.pack("HHII", nl_type, nl_flags, sequence, 0)
    genl_header = struct.pack("BBH", genl_cmd, genl_version, 0)
    return header + genl_header
308
309
def _genl_msg_finalize(msg):
    """Prepend nlmsg_len to a message built by _genl_msg().

    The length counts the 4-byte length field itself plus *msg*.
    """
    total_len = 4 + len(msg)
    return b''.join((struct.pack("I", total_len), msg))
312
313
def _genl_parse_mcast_groups(attr):
    """Return {group name: group id} from a CTRL_ATTR_MCAST_GROUPS nest.

    Entries without a (non-empty) name or without an id are skipped.
    """
    groups = {}
    for entry in NlAttrs(attr.raw):
        mcast_name = None
        mcast_id = None
        for entry_attr in NlAttrs(entry.raw):
            if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                mcast_name = entry_attr.as_strz()
            elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                mcast_id = entry_attr.as_scalar('u32')
        if mcast_name and mcast_id is not None:
            groups[mcast_name] = mcast_id
    return groups


def _genl_parse_family(nl_msg):
    """Decode one CTRL_CMD_GETFAMILY reply into a family dict.

    The result holds whichever of 'id', 'name', 'maxattr' and 'mcast' the
    reply carried.
    """
    gm = GenlMsg(nl_msg)
    fam = {}
    for attr in NlAttrs(gm.raw):
        if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
            fam['id'] = attr.as_scalar('u16')
        elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
            fam['name'] = attr.as_strz()
        elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
            fam['maxattr'] = attr.as_scalar('u32')
        elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
            fam['mcast'] = _genl_parse_mcast_groups(attr)
    return fam


def _genl_load_families():
    """Fill genl_family_name_to_id by dumping all families from nlctrl.

    Only families that report both a name and an id are recorded.  If the
    kernel answers with an error, it is reported on stderr and loading
    stops, keeping whatever was collected so far.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        sock.send(_genl_msg_finalize(msg), 0)

        genl_family_name_to_id = {}

        while True:
            reply = sock.recv(128 * 1024)
            for nl_msg in NlMsgs(reply):
                if nl_msg.error:
                    # Diagnostics go to stderr so they don't mix with data
                    # that callers print on stdout.
                    print("Netlink error:", nl_msg.error, file=sys.stderr)
                    return
                if nl_msg.done:
                    return

                fam = _genl_parse_family(nl_msg)
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam
361
362
class GenlMsg:
    """Generic netlink view of an NlMsg: genlmsghdr fields plus payload.

    ``raw`` is the payload after the 4-byte genlmsghdr.  ``raw_attrs`` is
    not set here; NetlinkProtocol.decode() attaches it once the family's
    fixed header size is known.
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        return self.genl_cmd

    def __repr__(self):
        # repr(self.nl) has no trailing newline; without one the genl line
        # would be glued onto it.
        msg = repr(self.nl) + '\n'
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # A GenlMsg built directly (not via decode()) has no raw_attrs.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg
378
379
class NetlinkProtocol:
    """Message framing for a "netlink-raw" family (bare nlmsghdr).

    GenlProtocol subclasses this and overrides the hooks for the extra
    genlmsghdr.
    """

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack an nlmsghdr minus nlmsg_len; port id is 0, seq random if unset."""
        if seq is None:
            seq = random.randint(1, 1024)
        return struct.pack("HHII", nl_type, nl_flags, seq, 0)

    def message(self, flags, command, version, seq=None):
        # Raw families carry the command in nlmsg_type; version is unused.
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap *nl_msg* and attach ``raw_attrs`` past the fixed header.

        When *op* is None it is looked up from the message's command.
        """
        decoded = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[decoded.cmd()]
        hdr_size = ynl._struct_size(op.fixed_header)
        decoded.raw_attrs = NlAttrs(decoded.raw, hdr_size)
        return decoded

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the spec-defined id of multicast group *mcast_name*."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        # sizeof(struct nlmsghdr)
        return 16
412
413
class GenlProtocol(NetlinkProtocol):
    """Generic netlink framing: nlmsghdr followed by a 4-byte genlmsghdr.

    The numeric family id and multicast group ids are taken from the
    kernel's family list, which is loaded on first use.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        # A family unknown to the kernel raises KeyError for the caller.
        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']

    def message(self, flags, command, version, seq=None):
        # nlmsg_type carries the family id; the command goes in genlmsghdr.
        header = self._message(self.family_id, flags, seq)
        return header + struct.pack("BBH", command, version, 0)

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the kernel-assigned id of multicast group *mcast_name*."""
        groups = self.genl_family['mcast']
        if mcast_name in groups:
            return groups[mcast_name]
        raise Exception(f'Multicast group "{mcast_name}" not present in the family')

    def msghdr_size(self):
        # nlmsghdr plus the 4-byte genlmsghdr
        return super().msghdr_size() + 4
440
441
class SpaceAttrs:
    """Stack of (attribute-set spec, values) scopes for attribute lookup.

    Lookups walk from the innermost scope outwards, so a sub-message
    selector can name an attribute of an enclosing nest.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer=None):
        self.scopes = [self.SpecValuesPair(attr_space, attrs)]
        if outer:
            self.scopes.extend(outer.scopes)

    def lookup(self, name):
        """Return the value of *name* from the innermost scope that defines it.

        Raises Exception if that scope has no value for it, or if no scope's
        spec defines the name at all.
        """
        for spec, values in self.scopes:
            if name not in spec:
                continue
            if name in values:
                return values[name]
            spec_name = spec.yaml['name']
            raise Exception(
                f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")
459
460
461#
462# YNL implementation details.
463#
464
465
466class YnlFamily(SpecFamily):
467    def __init__(self, def_path, schema=None, process_unknown=False,
468                 recv_size=0):
469        super().__init__(def_path, schema)
470
471        self.include_raw = False
472        self.process_unknown = process_unknown
473
474        try:
475            if self.proto == "netlink-raw":
476                self.nlproto = NetlinkProtocol(self.yaml['name'],
477                                               self.yaml['protonum'])
478            else:
479                self.nlproto = GenlProtocol(self.yaml['name'])
480        except KeyError:
481            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
482
483        self._recv_dbg = False
484        # Note that netlink will use conservative (min) message size for
485        # the first dump recv() on the socket, our setting will only matter
486        # from the second recv() on.
487        self._recv_size = recv_size if recv_size else 131072
488        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
489        # for a message, so smaller receive sizes will lead to truncation.
490        # Note that the min size for other families may be larger than 4k!
491        if self._recv_size < 4000:
492            raise ConfigError()
493
494        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
495        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
496        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
497        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
498
499        self.async_msg_ids = set()
500        self.async_msg_queue = queue.Queue()
501
502        for msg in self.msgs.values():
503            if msg.is_async:
504                self.async_msg_ids.add(msg.rsp_value)
505
506        for op_name, op in self.ops.items():
507            bound_f = functools.partial(self._op, op_name)
508            setattr(self, op.ident_name, bound_f)
509
510
511    def ntf_subscribe(self, mcast_name):
512        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
513        self.sock.bind((0, 0))
514        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
515                             mcast_id)
516
517    def set_recv_dbg(self, enabled):
518        self._recv_dbg = enabled
519
520    def _recv_dbg_print(self, reply, nl_msgs):
521        if not self._recv_dbg:
522            return
523        print("Recv: read", len(reply), "bytes,",
524              len(nl_msgs.msgs), "messages", file=sys.stderr)
525        for nl_msg in nl_msgs:
526            print("  ", nl_msg, file=sys.stderr)
527
528    def _encode_enum(self, attr_spec, value):
529        enum = self.consts[attr_spec['enum']]
530        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
531            scalar = 0
532            if isinstance(value, str):
533                value = [value]
534            for single_value in value:
535                scalar += enum.entries[single_value].user_value(as_flags = True)
536            return scalar
537        else:
538            return enum.entries[value].user_value()
539
540    def _get_scalar(self, attr_spec, value):
541        try:
542            return int(value)
543        except (ValueError, TypeError) as e:
544            if 'enum' in attr_spec:
545                return self._encode_enum(attr_spec, value)
546            if attr_spec.display_hint:
547                return self._from_string(value, attr_spec)
548            raise e
549
550    def _add_attr(self, space, name, value, search_attrs):
551        try:
552            attr = self.attr_sets[space][name]
553        except KeyError:
554            raise Exception(f"Space '{space}' has no attribute '{name}'")
555        nl_type = attr.value
556
557        if attr.is_multi and isinstance(value, list):
558            attr_payload = b''
559            for subvalue in value:
560                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
561            return attr_payload
562
563        if attr["type"] == 'nest':
564            nl_type |= Netlink.NLA_F_NESTED
565            attr_payload = b''
566            sub_space = attr['nested-attributes']
567            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
568            for subname, subvalue in value.items():
569                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
570        elif attr["type"] == 'flag':
571            if not value:
572                # If value is absent or false then skip attribute creation.
573                return b''
574            attr_payload = b''
575        elif attr["type"] == 'string':
576            attr_payload = str(value).encode('ascii') + b'\x00'
577        elif attr["type"] == 'binary':
578            if isinstance(value, bytes):
579                attr_payload = value
580            elif isinstance(value, str):
581                if attr.display_hint:
582                    attr_payload = self._from_string(value, attr)
583                else:
584                    attr_payload = bytes.fromhex(value)
585            elif isinstance(value, dict) and attr.struct_name:
586                attr_payload = self._encode_struct(attr.struct_name, value)
587            else:
588                raise Exception(f'Unknown type for binary attribute, value: {value}')
589        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
590            scalar = self._get_scalar(attr, value)
591            if attr.is_auto_scalar:
592                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
593            else:
594                attr_type = attr["type"]
595            format = NlAttr.get_format(attr_type, attr.byte_order)
596            attr_payload = format.pack(scalar)
597        elif attr['type'] in "bitfield32":
598            scalar_value = self._get_scalar(attr, value["value"])
599            scalar_selector = self._get_scalar(attr, value["selector"])
600            attr_payload = struct.pack("II", scalar_value, scalar_selector)
601        elif attr['type'] == 'sub-message':
602            msg_format, _ = self._resolve_selector(attr, search_attrs)
603            attr_payload = b''
604            if msg_format.fixed_header:
605                attr_payload += self._encode_struct(msg_format.fixed_header, value)
606            if msg_format.attr_set:
607                if msg_format.attr_set in self.attr_sets:
608                    nl_type |= Netlink.NLA_F_NESTED
609                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
610                    for subname, subvalue in value.items():
611                        attr_payload += self._add_attr(msg_format.attr_set,
612                                                       subname, subvalue, sub_attrs)
613                else:
614                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
615        else:
616            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
617
618        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
619        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
620
621    def _decode_enum(self, raw, attr_spec):
622        enum = self.consts[attr_spec['enum']]
623        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
624            i = 0
625            value = set()
626            while raw:
627                if raw & 1:
628                    value.add(enum.entries_by_val[i].name)
629                raw >>= 1
630                i += 1
631        else:
632            value = enum.entries_by_val[raw].name
633        return value
634
635    def _decode_binary(self, attr, attr_spec):
636        if attr_spec.struct_name:
637            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
638        elif attr_spec.sub_type:
639            decoded = attr.as_c_array(attr_spec.sub_type)
640            if 'enum' in attr_spec:
641                decoded = [ self._decode_enum(x, attr_spec) for x in decoded ]
642            elif attr_spec.display_hint:
643                decoded = [ self._formatted_string(x, attr_spec.display_hint)
644                            for x in decoded ]
645        else:
646            decoded = attr.as_bin()
647            if attr_spec.display_hint:
648                decoded = self._formatted_string(decoded, attr_spec.display_hint)
649        return decoded
650
651    def _decode_array_attr(self, attr, attr_spec):
652        decoded = []
653        offset = 0
654        while offset < len(attr.raw):
655            item = NlAttr(attr.raw, offset)
656            offset += item.full_len
657
658            if attr_spec["sub-type"] == 'nest':
659                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
660                decoded.append({ item.type: subattrs })
661            elif attr_spec["sub-type"] == 'binary':
662                subattr = item.as_bin()
663                if attr_spec.display_hint:
664                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
665                decoded.append(subattr)
666            elif attr_spec["sub-type"] in NlAttr.type_formats:
667                subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
668                if 'enum' in attr_spec:
669                    subattr = self._decode_enum(subattr, attr_spec)
670                elif attr_spec.display_hint:
671                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
672                decoded.append(subattr)
673            else:
674                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
675        return decoded
676
677    def _decode_nest_type_value(self, attr, attr_spec):
678        decoded = {}
679        value = attr
680        for name in attr_spec['type-value']:
681            value = NlAttr(value.raw, 0)
682            decoded[name] = value.type
683        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
684        decoded.update(subattrs)
685        return decoded
686
687    def _decode_unknown(self, attr):
688        if attr.is_nest:
689            return self._decode(NlAttrs(attr.raw), None)
690        else:
691            return attr.as_bin()
692
693    def _rsp_add(self, rsp, name, is_multi, decoded):
694        if is_multi == None:
695            if name in rsp and type(rsp[name]) is not list:
696                rsp[name] = [rsp[name]]
697                is_multi = True
698            else:
699                is_multi = False
700
701        if not is_multi:
702            rsp[name] = decoded
703        elif name in rsp:
704            rsp[name].append(decoded)
705        else:
706            rsp[name] = [decoded]
707
708    def _resolve_selector(self, attr_spec, search_attrs):
709        sub_msg = attr_spec.sub_message
710        if sub_msg not in self.sub_msgs:
711            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
712        sub_msg_spec = self.sub_msgs[sub_msg]
713
714        selector = attr_spec.selector
715        value = search_attrs.lookup(selector)
716        if value not in sub_msg_spec.formats:
717            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
718
719        spec = sub_msg_spec.formats[value]
720        return spec, value
721
722    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
723        msg_format, _ = self._resolve_selector(attr_spec, search_attrs)
724        decoded = {}
725        offset = 0
726        if msg_format.fixed_header:
727            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header));
728            offset = self._struct_size(msg_format.fixed_header)
729        if msg_format.attr_set:
730            if msg_format.attr_set in self.attr_sets:
731                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
732                decoded.update(subdict)
733            else:
734                raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")
735        return decoded
736
737    def _decode(self, attrs, space, outer_attrs = None):
738        rsp = dict()
739        if space:
740            attr_space = self.attr_sets[space]
741            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
742
743        for attr in attrs:
744            try:
745                attr_spec = attr_space.attrs_by_val[attr.type]
746            except (KeyError, UnboundLocalError):
747                if not self.process_unknown:
748                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
749                attr_name = f"UnknownAttr({attr.type})"
750                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
751                continue
752
753            try:
754                if attr_spec["type"] == 'nest':
755                    subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
756                    decoded = subdict
757                elif attr_spec["type"] == 'string':
758                    decoded = attr.as_strz()
759                elif attr_spec["type"] == 'binary':
760                    decoded = self._decode_binary(attr, attr_spec)
761                elif attr_spec["type"] == 'flag':
762                    decoded = True
763                elif attr_spec.is_auto_scalar:
764                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
765                    if 'enum' in attr_spec:
766                        decoded = self._decode_enum(decoded, attr_spec)
767                elif attr_spec["type"] in NlAttr.type_formats:
768                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
769                    if 'enum' in attr_spec:
770                        decoded = self._decode_enum(decoded, attr_spec)
771                    elif attr_spec.display_hint:
772                        decoded = self._formatted_string(decoded, attr_spec.display_hint)
773                elif attr_spec["type"] == 'indexed-array':
774                    decoded = self._decode_array_attr(attr, attr_spec)
775                elif attr_spec["type"] == 'bitfield32':
776                    value, selector = struct.unpack("II", attr.raw)
777                    if 'enum' in attr_spec:
778                        value = self._decode_enum(value, attr_spec)
779                        selector = self._decode_enum(selector, attr_spec)
780                    decoded = {"value": value, "selector": selector}
781                elif attr_spec["type"] == 'sub-message':
782                    decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
783                elif attr_spec["type"] == 'nest-type-value':
784                    decoded = self._decode_nest_type_value(attr, attr_spec)
785                else:
786                    if not self.process_unknown:
787                        raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
788                    decoded = self._decode_unknown(attr)
789
790                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
791            except:
792                print(f"Error decoding '{attr_spec.name}' from '{space}'")
793                raise
794
795        return rsp
796
797    def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs):
798        for attr in attrs:
799            try:
800                attr_spec = attr_set.attrs_by_val[attr.type]
801            except KeyError:
802                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
803            if offset > target:
804                break
805            if offset == target:
806                return '.' + attr_spec.name
807
808            if offset + attr.full_len <= target:
809                offset += attr.full_len
810                continue
811
812            pathname = attr_spec.name
813            if attr_spec['type'] == 'nest':
814                sub_attrs = self.attr_sets[attr_spec['nested-attributes']]
815                search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name']))
816            elif attr_spec['type'] == 'sub-message':
817                msg_format, value = self._resolve_selector(attr_spec, search_attrs)
818                if msg_format is None:
819                    raise Exception(f"Can't resolve sub-message of {attr_spec['name']} for extack")
820                sub_attrs = self.attr_sets[msg_format.attr_set]
821                pathname += f"({value})"
822            else:
823                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
824            offset += 4
825            subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs,
826                                               offset, target, search_attrs)
827            if subpath is None:
828                return None
829            return '.' + pathname + subpath
830
831        return None
832
833    def _decode_extack(self, request, op, extack, vals):
834        if 'bad-attr-offs' not in extack:
835            return
836
837        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
838        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
839        search_attrs = SpaceAttrs(op.attr_set, vals)
840        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
841                                        extack['bad-attr-offs'], search_attrs)
842        if path:
843            del extack['bad-attr-offs']
844            extack['bad-attr'] = path
845
846    def _struct_size(self, name):
847        if name:
848            members = self.consts[name].members
849            size = 0
850            for m in members:
851                if m.type in ['pad', 'binary']:
852                    if m.struct:
853                        size += self._struct_size(m.struct)
854                    else:
855                        size += m.len
856                else:
857                    format = NlAttr.get_format(m.type, m.byte_order)
858                    size += format.size
859            return size
860        else:
861            return 0
862
863    def _decode_struct(self, data, name):
864        members = self.consts[name].members
865        attrs = dict()
866        offset = 0
867        for m in members:
868            value = None
869            if m.type == 'pad':
870                offset += m.len
871            elif m.type == 'binary':
872                if m.struct:
873                    len = self._struct_size(m.struct)
874                    value = self._decode_struct(data[offset : offset + len],
875                                                m.struct)
876                    offset += len
877                else:
878                    value = data[offset : offset + m.len]
879                    offset += m.len
880            else:
881                format = NlAttr.get_format(m.type, m.byte_order)
882                [ value ] = format.unpack_from(data, offset)
883                offset += format.size
884            if value is not None:
885                if m.enum:
886                    value = self._decode_enum(value, m)
887                elif m.display_hint:
888                    value = self._formatted_string(value, m.display_hint)
889                attrs[m.name] = value
890        return attrs
891
892    def _encode_struct(self, name, vals):
893        members = self.consts[name].members
894        attr_payload = b''
895        for m in members:
896            value = vals.pop(m.name) if m.name in vals else None
897            if m.type == 'pad':
898                attr_payload += bytearray(m.len)
899            elif m.type == 'binary':
900                if m.struct:
901                    if value is None:
902                        value = dict()
903                    attr_payload += self._encode_struct(m.struct, value)
904                else:
905                    if value is None:
906                        attr_payload += bytearray(m.len)
907                    else:
908                        attr_payload += bytes.fromhex(value)
909            else:
910                if value is None:
911                    value = 0
912                format = NlAttr.get_format(m.type, m.byte_order)
913                attr_payload += format.pack(value)
914        return attr_payload
915
916    def _formatted_string(self, raw, display_hint):
917        if display_hint == 'mac':
918            formatted = ':'.join('%02x' % b for b in raw)
919        elif display_hint == 'hex':
920            if isinstance(raw, int):
921                formatted = hex(raw)
922            else:
923                formatted = bytes.hex(raw, ' ')
924        elif display_hint in [ 'ipv4', 'ipv6' ]:
925            formatted = format(ipaddress.ip_address(raw))
926        elif display_hint == 'uuid':
927            formatted = str(uuid.UUID(bytes=raw))
928        else:
929            formatted = raw
930        return formatted
931
932    def _from_string(self, string, attr_spec):
933        if attr_spec.display_hint in ['ipv4', 'ipv6']:
934            ip = ipaddress.ip_address(string)
935            if attr_spec['type'] == 'binary':
936                raw = ip.packed
937            else:
938                raw = int(ip)
939        else:
940            raise Exception(f"Display hint '{attr_spec.display_hint}' not implemented"
941                            f" when parsing '{attr_spec['name']}'")
942        return raw
943
944    def handle_ntf(self, decoded):
945        msg = dict()
946        if self.include_raw:
947            msg['raw'] = decoded
948        op = self.rsp_by_value[decoded.cmd()]
949        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
950        if op.fixed_header:
951            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
952
953        msg['name'] = op['name']
954        msg['msg'] = attrs
955        self.async_msg_queue.put(msg)
956
957    def check_ntf(self):
958        while True:
959            try:
960                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
961            except BlockingIOError:
962                return
963
964            nms = NlMsgs(reply)
965            self._recv_dbg_print(reply, nms)
966            for nl_msg in nms:
967                if nl_msg.error:
968                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
969                    print(nl_msg)
970                    continue
971                if nl_msg.done:
972                    print("Netlink done while checking for ntf!?")
973                    continue
974
975                decoded = self.nlproto.decode(self, nl_msg, None)
976                if decoded.cmd() not in self.async_msg_ids:
977                    print("Unexpected msg id while checking for ntf", decoded)
978                    continue
979
980                self.handle_ntf(decoded)
981
982    def poll_ntf(self, duration=None):
983        start_time = time.time()
984        selector = selectors.DefaultSelector()
985        selector.register(self.sock, selectors.EVENT_READ)
986
987        while True:
988            try:
989                yield self.async_msg_queue.get_nowait()
990            except queue.Empty:
991                if duration is not None:
992                    timeout = start_time + duration - time.time()
993                    if timeout <= 0:
994                        return
995                else:
996                    timeout = None
997                events = selector.select(timeout)
998                if events:
999                    self.check_ntf()
1000
1001    def operation_do_attributes(self, name):
1002      """
1003      For a given operation name, find and return a supported
1004      set of attributes (as a dict).
1005      """
1006      op = self.find_operation(name)
1007      if not op:
1008        return None
1009
1010      return op['do']['request']['attributes'].copy()
1011
1012    def _encode_message(self, op, vals, flags, req_seq):
1013        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
1014        for flag in flags or []:
1015            nl_flags |= flag
1016
1017        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
1018        if op.fixed_header:
1019            msg += self._encode_struct(op.fixed_header, vals)
1020        search_attrs = SpaceAttrs(op.attr_set, vals)
1021        for name, value in vals.items():
1022            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
1023        msg = _genl_msg_finalize(msg)
1024        return msg
1025
    def _ops(self, ops):
        """
        Send a batch of requests and collect one response per request.

        ops: iterable of (method, vals, flags) tuples.  `method` is looked up
             in self.ops, `vals` holds the request values, and `flags` is a
             list of extra NLM_F_* flags (or None).

        All requests get consecutive sequence numbers and are sent in a
        single send() call.  Replies are then read until every request has
        received its completion message.  For each request, the entry
        appended to the returned list is:
          - the list of decoded replies if NLM_F_DUMP is in its flags,
          - None if no replies were received,
          - the single decoded reply if exactly one was received,
          - otherwise the list of decoded replies.
        Entries are appended in the order completion messages arrive.

        Raises NlError as soon as any received message reports an error.
        Messages that are not replies to these requests are passed to
        handle_ntf() when their command is a notification id, and are
        printed otherwise.
        """
        reqs_by_seq = {}
        # Random starting sequence number; each request gets the next one.
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            # vals is the same dict given to _encode_message(), so any
            # fixed-header members have already been popped from it by
            # _encode_struct().
            reqs_by_seq[req_seq] = (op, vals, msg, flags)
            payload += msg
            req_seq += 1

        self.sock.send(payload, 0)

        done = False
        rsp = []
        # NOTE(review): op_rsp is one accumulator shared by every request in
        # the batch.  Replies are attributed to whichever request completes
        # next, which assumes the kernel answers the requests strictly in
        # order - confirm.
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    # Reply to one of our requests: resolve extack offsets
                    # against the request bytes we actually sent.
                    (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    if nl_msg.extack:
                        nl_msg.annotate_extack(op.attr_set)
                        self._decode_extack(req_msg, op, nl_msg.extack, vals)
                else:
                    # Not one of ours (e.g. a notification): no request context.
                    op = None
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                # Completion of a request.  Presumably this covers both
                # NLMSG_DONE and a zero-error ACK (NlMsg is defined
                # elsewhere - confirm).
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    # A dump is recognised only when NLM_F_DUMP itself is an
                    # element of the flags list.
                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    # Stop processing this recv() buffer.  Any messages after
                    # the completion in the same buffer are not looked at.
                    # NOTE(review): presumably the kernel never places further
                    # replies after a completion in the same buffer - confirm.
                    break

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp
1094
1095    def _op(self, method, vals, flags=None, dump=False):
1096        req_flags = flags or []
1097        if dump:
1098            req_flags.append(Netlink.NLM_F_DUMP)
1099
1100        ops = [(method, vals, req_flags)]
1101        return self._ops(ops)[0]
1102
1103    def do(self, method, vals, flags=None):
1104        return self._op(method, vals, flags)
1105
1106    def dump(self, method, vals):
1107        return self._op(method, vals, dump=True)
1108
1109    def do_multi(self, ops):
1110        return self._ops(ops)
1111