xref: /linux/tools/net/ynl/pyynl/lib/ynl.py (revision 8c2e602225f0a96f2c5c65de8ab06e304081e542)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4from enum import Enum
5import functools
6import os
7import random
8import socket
9import struct
10from struct import Struct
11import sys
12import yaml
13import ipaddress
14import uuid
15import queue
16import selectors
17import time
18
19from .nlspec import SpecFamily
20
21#
22# Generic Netlink code which should really be in some library, but I can't quickly find one.
23#
24
25
class Netlink:
    """Netlink, generic netlink (genetlink) and nlctrl constants used by this module.

    Values are used directly when building requests, setting socket options
    and parsing replies. Several NLM_F_* values repeat (0x100, 0x200): the
    meaning of those flag bits depends on the message they appear on.
    """
    # Netlink socket
    SOL_NETLINK = 270

    # Socket options set with setsockopt(SOL_NETLINK, ...)
    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # Bits of the nlmsghdr flags field
    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    # These reuse the 0x100/0x200/... bits above; which meaning applies
    # depends on the message they are set on.
    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    # Flags found on ACK/error replies; NlMsg checks NLM_F_ACK_TLVS to decide
    # whether extended-ack attributes follow.
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # Flag bits carried in an attribute header's type field; NlAttr masks
    # them off with NLA_TYPE_MASK to get the plain attribute type.
    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    # Message type addressed when querying the nlctrl family for the list of
    # generic netlink families.
    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    # Attributes inside each CTRL_ATTR_MCAST_GROUPS entry
    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types (attributes of an extended ACK, parsed by NlMsg)
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types (attributes of an NLMSGERR_ATTR_POLICY nest, decoded by
    # NlMsg._decode_policy)
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Attribute types reported in NL_POLICY_TYPE_ATTR_TYPE. The functional
    # Enum API numbers members from 1, so 'flag' == 1, 'u8' == 2, and so on.
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                  's8', 's16', 's32', 's64',
                                  'binary', 'string', 'nul-string',
                                  'nested', 'nested-array',
                                  'bitfield32', 'sint', 'uint'])
102
class NlError(Exception):
    """Exception wrapping a netlink message whose error field is set.

    ``nl_msg`` is the offending message. ``error`` is the errno as a positive
    number: the message stores it negated, and it is fed to os.strerror().
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        errno_value = -nl_msg.error
        self.error = errno_value

    def __str__(self):
        reason = os.strerror(self.error)
        return "Netlink error: " + reason + "\n" + str(self.nl_msg)
110
111
class ConfigError(Exception):
    """Raised when a YnlFamily is constructed with an invalid configuration."""
114
115
class NlAttr:
    """A single netlink attribute (TLV) read from a raw byte buffer.

    Attributes:
        type:        attribute type with the NLA_F_NESTED / NLA_F_NET_BYTEORDER
                     bits masked off.
        is_nest:     non-zero when NLA_F_NESTED was set in the header.
        payload_len: the header's length field, i.e. the 4-byte header plus
                     the payload, without alignment padding.
        full_len:    payload_len rounded up to a multiple of 4; the distance
                     from this attribute to the next one in the buffer.
        raw:         the payload bytes (header and padding excluded).
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute starting at *offset* in *raw*.

        Raises:
            Exception: if the length field is smaller than the 4-byte header.
        """
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # The length field covers the header itself, so anything below 4 is
        # malformed. Rejecting it also keeps callers that advance by full_len
        # from looping forever on a zero-length attribute.
        if self._len < 4:
            raise Exception(f"Invalid netlink attribute length {self._len} at offset {offset}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for *attr_type* in the given byte order.

        *byte_order* "big-endian" selects big endian, any other non-empty
        value little endian, and None/empty the host's native order.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" \
                else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the payload as a fixed-size integer of *attr_type* (e.g. 'u32')."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a variable-width 'uint'/'sint' payload (4 or 8 bytes)."""
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        # 'uint' -> 'u32'/'u64', 'sint' -> 's32'/'s64'
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as an ASCII string, dropping one trailing NUL.

        A payload without a terminating NUL is returned whole rather than
        losing its last character.
        """
        text = self.raw.decode('ascii')
        if text.endswith('\x00'):
            text = text[:-1]
        return text

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as a packed array of *type* in native byte order."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
168
169
class NlAttrs:
    """Ordered list of the netlink attributes found in a buffer.

    Parsing starts at *offset* and steps from attribute to attribute by each
    one's padded length until the end of *msg* is reached.
    """

    def __init__(self, msg, offset=0):
        self.attrs = []
        pos = offset
        end = len(msg)
        while pos < end:
            current = NlAttr(msg, pos)
            self.attrs.append(current)
            pos += current.full_len

    def __iter__(self):
        return iter(self.attrs)

    def __repr__(self):
        # One attribute per line.
        return '\n'.join(repr(item) for item in self.attrs)
189
190
class NlMsg:
    """A netlink message parsed out of a receive buffer.

    Decodes the 16-byte nlmsghdr at *offset* and keeps the rest of the message
    in ``raw``. For NLMSG_ERROR and NLMSG_DONE the error code is stored in
    ``error`` and ``done`` is set. When the message carries NLM_F_ACK_TLVS,
    the extended-ack attributes are decoded into the ``extack`` dict.

    Args:
        msg: buffer holding one or more netlink messages.
        offset: start of this message within *msg*.
        attr_space: optional attribute set used to translate a reported
            missing attribute type into its name (see annotate_extack()).
    """
    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # error code (4) + echoed request header (16).
            # NOTE(review): assumes a capped ACK (NETLINK_CAP_ACK is set on the
            # sockets in this file); an uncapped ACK echoes the whole request.
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                self.annotate_extack(attr_space)

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY nest into a dict of policy limits."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_val = attr.as_scalar('u32')
                try:
                    policy['type'] = Netlink.AttrType(type_val).name
                except ValueError:
                    # A type this table doesn't know (e.g. added by a newer
                    # kernel) must not mask the error being reported.
                    policy['type'] = f"Unknown({type_val})"
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def annotate_extack(self, attr_space):
        """ Make extack more human friendly with attribute information """

        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
            miss_type = self.extack['miss-type']
            if miss_type in attr_space.attrs_by_val:
                spec = attr_space.attrs_by_val[miss_type]
                self.extack['miss-type'] = spec['name']
                if 'doc' in spec:
                    self.extack['miss-type-doc'] = spec['doc']

    def cmd(self):
        """Return the message type (nlmsg_type)."""
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg
282
283
class NlMsgs:
    """All netlink messages contained in one receive buffer, in order."""

    def __init__(self, data):
        """Split *data* into NlMsg objects.

        Raises:
            Exception: if a message header declares a length smaller than the
                16-byte nlmsghdr.
        """
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset)
            # nlmsg_len includes the 16-byte header; anything smaller is
            # malformed and a zero length would make this loop spin forever.
            if msg.nl_len < 16:
                raise Exception(f"Invalid netlink message length {msg.nl_len} at offset {offset}")
            # Messages in a buffer start on 4-byte boundaries (NLMSG_ALIGN),
            # while nlmsg_len itself need not be a multiple of 4.
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
296
297
# Cache of the generic netlink families reported by the kernel, keyed by
# family name. Each value is a dict with 'name' and 'id', plus 'maxattr' and
# 'mcast' (multicast group name -> group id) when reported. Stays None until
# _genl_load_families() fills it; GenlProtocol triggers that on first use.
genl_family_name_to_id = None
299
300
301def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
302    # we prepend length in _genl_msg_finalize()
303    if seq is None:
304        seq = random.randint(1, 1024)
305    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
306    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
307    return nlmsg + genlmsg
308
309
310def _genl_msg_finalize(msg):
311    return struct.pack("I", len(msg) + 4) + msg
312
313
def _genl_parse_mcast_groups(attr):
    """Return {group name: group id} from a CTRL_ATTR_MCAST_GROUPS nest.

    Entries that lack a name or an id are skipped.
    """
    groups = dict()
    for entry in NlAttrs(attr.raw):
        mcast_name = None
        mcast_id = None
        for entry_attr in NlAttrs(entry.raw):
            if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                mcast_name = entry_attr.as_strz()
            elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                mcast_id = entry_attr.as_scalar('u32')
        if mcast_name and mcast_id is not None:
            groups[mcast_name] = mcast_id
    return groups


def _genl_parse_family(gm):
    """Collect id, name, maxattr and multicast groups from one nlctrl reply."""
    fam = dict()
    for attr in NlAttrs(gm.raw):
        if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
            fam['id'] = attr.as_scalar('u16')
        elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
            fam['name'] = attr.as_strz()
        elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
            fam['maxattr'] = attr.as_scalar('u32')
        elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
            fam['mcast'] = _genl_parse_mcast_groups(attr)
    return fam


def _genl_load_families():
    """Dump all generic netlink families and cache them in genl_family_name_to_id.

    Sends a CTRL_CMD_GETFAMILY dump request to nlctrl and stores every reply
    that carries both a name and an id. On a netlink error the error is
    reported on stderr and whatever was collected so far is kept.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            for nl_msg in NlMsgs(reply):
                if nl_msg.error:
                    # Diagnostics go to stderr, like the rest of the debug output.
                    print("Netlink error:", nl_msg.error, file=sys.stderr)
                    return
                if nl_msg.done:
                    return

                fam = _genl_parse_family(GenlMsg(nl_msg))
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam
361
362
class GenlMsg:
    """A generic netlink message: a parsed NlMsg plus its genlmsghdr.

    ``raw`` is the payload that follows the 4-byte genl header (cmd, version,
    reserved). ``raw_attrs`` is not set here; NetlinkProtocol.decode()
    attaches it once the family's fixed header size is known.
    """
    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the generic netlink command."""
        return self.genl_cmd

    def __repr__(self):
        # The NlMsg repr has no trailing newline, so start the genl line
        # on a new one.
        msg = repr(self.nl)
        msg += f"\n\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs only exists after NetlinkProtocol.decode(); messages that
        # were never decoded must still be printable.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg
378
379
class NetlinkProtocol:
    """Message framing for a classic (non-generic) netlink family.

    GenlProtocol subclasses this and overrides the framing hooks.
    """
    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack nlmsghdr minus nlmsg_len: type, flags, seq (random if None), port id 0."""
        sequence = random.randint(1, 1024) if seq is None else seq
        return struct.pack("HHII", nl_type, nl_flags, sequence, 0)

    def message(self, flags, command, version, seq=None):
        # Classic netlink has no version field; the command is the message type.
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        # Hook for subclasses; classic netlink needs no extra header parsing.
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap *nl_msg* and attach its attributes as ``raw_attrs``.

        When *op* is None it is looked up from the message's command. The
        op's fixed header, if any, is skipped before parsing attributes.
        """
        decoded_msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[decoded_msg.cmd()]
        header_len = ynl._struct_size(op.fixed_header)
        decoded_msg.raw_attrs = NlAttrs(decoded_msg.raw, header_len)
        return decoded_msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the multicast group ID declared for *mcast_name* in the spec."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        # Size of nlmsghdr.
        return 16
412
413
class GenlProtocol(NetlinkProtocol):
    """Message framing for generic netlink families.

    The family's kernel-assigned ID (from the nlctrl cache) is used as the
    netlink message type, and a genlmsghdr follows the nlmsghdr.
    """
    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']

    def message(self, flags, command, version, seq=None):
        header = self._message(self.family_id, flags, seq)
        return header + struct.pack("BBH", command, version, 0)

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        # Group IDs are taken from what the kernel reported, not from the spec.
        known_groups = self.genl_family['mcast']
        if mcast_name not in known_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return known_groups[mcast_name]

    def msghdr_size(self):
        # nlmsghdr plus the 4-byte genlmsghdr.
        return super().msghdr_size() + 4
440
441
class SpaceAttrs:
    """A chain of (attribute-set spec, values) scopes used for name lookup.

    lookup() searches the innermost scope first and then walks outwards
    through *outer*'s scopes; sub-message selectors are resolved this way.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer=None):
        self.scopes = [self.SpecValuesPair(attr_space, attrs)]
        if outer:
            self.scopes.extend(outer.scopes)

    def lookup(self, name):
        """Return the value of attribute *name* from the nearest scope defining it.

        Raises:
            Exception: if the defining scope has no value for *name*, or no
                scope's spec defines it.
        """
        for spec, values in self.scopes:
            if name not in spec:
                continue
            if name in values:
                return values[name]
            raise Exception(
                f"No value for '{name}' in attribute space '{spec.yaml['name']}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")
459
460
461#
462# YNL implementation details.
463#
464
465
466class YnlFamily(SpecFamily):
467    def __init__(self, def_path, schema=None, process_unknown=False,
468                 recv_size=0):
469        super().__init__(def_path, schema)
470
471        self.include_raw = False
472        self.process_unknown = process_unknown
473
474        try:
475            if self.proto == "netlink-raw":
476                self.nlproto = NetlinkProtocol(self.yaml['name'],
477                                               self.yaml['protonum'])
478            else:
479                self.nlproto = GenlProtocol(self.yaml['name'])
480        except KeyError:
481            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
482
483        self._recv_dbg = False
484        # Note that netlink will use conservative (min) message size for
485        # the first dump recv() on the socket, our setting will only matter
486        # from the second recv() on.
487        self._recv_size = recv_size if recv_size else 131072
488        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
489        # for a message, so smaller receive sizes will lead to truncation.
490        # Note that the min size for other families may be larger than 4k!
491        if self._recv_size < 4000:
492            raise ConfigError()
493
494        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
495        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
496        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
497        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
498
499        self.async_msg_ids = set()
500        self.async_msg_queue = queue.Queue()
501
502        for msg in self.msgs.values():
503            if msg.is_async:
504                self.async_msg_ids.add(msg.rsp_value)
505
506        for op_name, op in self.ops.items():
507            bound_f = functools.partial(self._op, op_name)
508            setattr(self, op.ident_name, bound_f)
509
510
511    def ntf_subscribe(self, mcast_name):
512        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
513        self.sock.bind((0, 0))
514        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
515                             mcast_id)
516
517    def set_recv_dbg(self, enabled):
518        self._recv_dbg = enabled
519
520    def _recv_dbg_print(self, reply, nl_msgs):
521        if not self._recv_dbg:
522            return
523        print("Recv: read", len(reply), "bytes,",
524              len(nl_msgs.msgs), "messages", file=sys.stderr)
525        for nl_msg in nl_msgs:
526            print("  ", nl_msg, file=sys.stderr)
527
528    def _encode_enum(self, attr_spec, value):
529        enum = self.consts[attr_spec['enum']]
530        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
531            scalar = 0
532            if isinstance(value, str):
533                value = [value]
534            for single_value in value:
535                scalar += enum.entries[single_value].user_value(as_flags = True)
536            return scalar
537        else:
538            return enum.entries[value].user_value()
539
540    def _get_scalar(self, attr_spec, value):
541        try:
542            return int(value)
543        except (ValueError, TypeError) as e:
544            if 'enum' in attr_spec:
545                return self._encode_enum(attr_spec, value)
546            if attr_spec.display_hint:
547                return self._from_string(value, attr_spec)
548            raise e
549
550    def _add_attr(self, space, name, value, search_attrs):
551        try:
552            attr = self.attr_sets[space][name]
553        except KeyError:
554            raise Exception(f"Space '{space}' has no attribute '{name}'")
555        nl_type = attr.value
556
557        if attr.is_multi and isinstance(value, list):
558            attr_payload = b''
559            for subvalue in value:
560                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
561            return attr_payload
562
563        if attr["type"] == 'nest':
564            nl_type |= Netlink.NLA_F_NESTED
565            attr_payload = b''
566            sub_space = attr['nested-attributes']
567            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
568            for subname, subvalue in value.items():
569                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
570        elif attr["type"] == 'flag':
571            if not value:
572                # If value is absent or false then skip attribute creation.
573                return b''
574            attr_payload = b''
575        elif attr["type"] == 'string':
576            attr_payload = str(value).encode('ascii') + b'\x00'
577        elif attr["type"] == 'binary':
578            if isinstance(value, bytes):
579                attr_payload = value
580            elif isinstance(value, str):
581                if attr.display_hint:
582                    attr_payload = self._from_string(value, attr)
583                else:
584                    attr_payload = bytes.fromhex(value)
585            elif isinstance(value, dict) and attr.struct_name:
586                attr_payload = self._encode_struct(attr.struct_name, value)
587            else:
588                raise Exception(f'Unknown type for binary attribute, value: {value}')
589        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
590            scalar = self._get_scalar(attr, value)
591            if attr.is_auto_scalar:
592                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
593            else:
594                attr_type = attr["type"]
595            format = NlAttr.get_format(attr_type, attr.byte_order)
596            attr_payload = format.pack(scalar)
597        elif attr['type'] in "bitfield32":
598            scalar_value = self._get_scalar(attr, value["value"])
599            scalar_selector = self._get_scalar(attr, value["selector"])
600            attr_payload = struct.pack("II", scalar_value, scalar_selector)
601        elif attr['type'] == 'sub-message':
602            msg_format, _ = self._resolve_selector(attr, search_attrs)
603            attr_payload = b''
604            if msg_format.fixed_header:
605                attr_payload += self._encode_struct(msg_format.fixed_header, value)
606            if msg_format.attr_set:
607                if msg_format.attr_set in self.attr_sets:
608                    nl_type |= Netlink.NLA_F_NESTED
609                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
610                    for subname, subvalue in value.items():
611                        attr_payload += self._add_attr(msg_format.attr_set,
612                                                       subname, subvalue, sub_attrs)
613                else:
614                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
615        else:
616            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
617
618        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
619        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
620
621    def _get_enum_or_unknown(self, enum, raw):
622        try:
623            name = enum.entries_by_val[raw].name
624        except KeyError as error:
625            if self.process_unknown:
626                name = f"Unknown({raw})"
627            else:
628                raise error
629        return name
630
631    def _decode_enum(self, raw, attr_spec):
632        enum = self.consts[attr_spec['enum']]
633        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
634            i = 0
635            value = set()
636            while raw:
637                if raw & 1:
638                    value.add(self._get_enum_or_unknown(enum, i))
639                raw >>= 1
640                i += 1
641        else:
642            value = self._get_enum_or_unknown(enum, raw)
643        return value
644
645    def _decode_binary(self, attr, attr_spec):
646        if attr_spec.struct_name:
647            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
648        elif attr_spec.sub_type:
649            decoded = attr.as_c_array(attr_spec.sub_type)
650            if 'enum' in attr_spec:
651                decoded = [ self._decode_enum(x, attr_spec) for x in decoded ]
652            elif attr_spec.display_hint:
653                decoded = [ self._formatted_string(x, attr_spec.display_hint)
654                            for x in decoded ]
655        else:
656            decoded = attr.as_bin()
657            if attr_spec.display_hint:
658                decoded = self._formatted_string(decoded, attr_spec.display_hint)
659        return decoded
660
661    def _decode_array_attr(self, attr, attr_spec):
662        decoded = []
663        offset = 0
664        while offset < len(attr.raw):
665            item = NlAttr(attr.raw, offset)
666            offset += item.full_len
667
668            if attr_spec["sub-type"] == 'nest':
669                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
670                decoded.append({ item.type: subattrs })
671            elif attr_spec["sub-type"] == 'binary':
672                subattr = item.as_bin()
673                if attr_spec.display_hint:
674                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
675                decoded.append(subattr)
676            elif attr_spec["sub-type"] in NlAttr.type_formats:
677                subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
678                if 'enum' in attr_spec:
679                    subattr = self._decode_enum(subattr, attr_spec)
680                elif attr_spec.display_hint:
681                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
682                decoded.append(subattr)
683            else:
684                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
685        return decoded
686
687    def _decode_nest_type_value(self, attr, attr_spec):
688        decoded = {}
689        value = attr
690        for name in attr_spec['type-value']:
691            value = NlAttr(value.raw, 0)
692            decoded[name] = value.type
693        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
694        decoded.update(subattrs)
695        return decoded
696
697    def _decode_unknown(self, attr):
698        if attr.is_nest:
699            return self._decode(NlAttrs(attr.raw), None)
700        else:
701            return attr.as_bin()
702
703    def _rsp_add(self, rsp, name, is_multi, decoded):
704        if is_multi == None:
705            if name in rsp and type(rsp[name]) is not list:
706                rsp[name] = [rsp[name]]
707                is_multi = True
708            else:
709                is_multi = False
710
711        if not is_multi:
712            rsp[name] = decoded
713        elif name in rsp:
714            rsp[name].append(decoded)
715        else:
716            rsp[name] = [decoded]
717
718    def _resolve_selector(self, attr_spec, search_attrs):
719        sub_msg = attr_spec.sub_message
720        if sub_msg not in self.sub_msgs:
721            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
722        sub_msg_spec = self.sub_msgs[sub_msg]
723
724        selector = attr_spec.selector
725        value = search_attrs.lookup(selector)
726        if value not in sub_msg_spec.formats:
727            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
728
729        spec = sub_msg_spec.formats[value]
730        return spec, value
731
732    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
733        msg_format, _ = self._resolve_selector(attr_spec, search_attrs)
734        decoded = {}
735        offset = 0
736        if msg_format.fixed_header:
737            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header));
738            offset = self._struct_size(msg_format.fixed_header)
739        if msg_format.attr_set:
740            if msg_format.attr_set in self.attr_sets:
741                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
742                decoded.update(subdict)
743            else:
744                raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")
745        return decoded
746
747    def _decode(self, attrs, space, outer_attrs = None):
748        rsp = dict()
749        if space:
750            attr_space = self.attr_sets[space]
751            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
752
753        for attr in attrs:
754            try:
755                attr_spec = attr_space.attrs_by_val[attr.type]
756            except (KeyError, UnboundLocalError):
757                if not self.process_unknown:
758                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
759                attr_name = f"UnknownAttr({attr.type})"
760                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
761                continue
762
763            try:
764                if attr_spec["type"] == 'nest':
765                    subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
766                    decoded = subdict
767                elif attr_spec["type"] == 'string':
768                    decoded = attr.as_strz()
769                elif attr_spec["type"] == 'binary':
770                    decoded = self._decode_binary(attr, attr_spec)
771                elif attr_spec["type"] == 'flag':
772                    decoded = True
773                elif attr_spec.is_auto_scalar:
774                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
775                    if 'enum' in attr_spec:
776                        decoded = self._decode_enum(decoded, attr_spec)
777                elif attr_spec["type"] in NlAttr.type_formats:
778                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
779                    if 'enum' in attr_spec:
780                        decoded = self._decode_enum(decoded, attr_spec)
781                    elif attr_spec.display_hint:
782                        decoded = self._formatted_string(decoded, attr_spec.display_hint)
783                elif attr_spec["type"] == 'indexed-array':
784                    decoded = self._decode_array_attr(attr, attr_spec)
785                elif attr_spec["type"] == 'bitfield32':
786                    value, selector = struct.unpack("II", attr.raw)
787                    if 'enum' in attr_spec:
788                        value = self._decode_enum(value, attr_spec)
789                        selector = self._decode_enum(selector, attr_spec)
790                    decoded = {"value": value, "selector": selector}
791                elif attr_spec["type"] == 'sub-message':
792                    decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
793                elif attr_spec["type"] == 'nest-type-value':
794                    decoded = self._decode_nest_type_value(attr, attr_spec)
795                else:
796                    if not self.process_unknown:
797                        raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
798                    decoded = self._decode_unknown(attr)
799
800                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
801            except:
802                print(f"Error decoding '{attr_spec.name}' from '{space}'")
803                raise
804
805        return rsp
806
807    def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs):
808        for attr in attrs:
809            try:
810                attr_spec = attr_set.attrs_by_val[attr.type]
811            except KeyError:
812                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
813            if offset > target:
814                break
815            if offset == target:
816                return '.' + attr_spec.name
817
818            if offset + attr.full_len <= target:
819                offset += attr.full_len
820                continue
821
822            pathname = attr_spec.name
823            if attr_spec['type'] == 'nest':
824                sub_attrs = self.attr_sets[attr_spec['nested-attributes']]
825                search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name']))
826            elif attr_spec['type'] == 'sub-message':
827                msg_format, value = self._resolve_selector(attr_spec, search_attrs)
828                if msg_format is None:
829                    raise Exception(f"Can't resolve sub-message of {attr_spec['name']} for extack")
830                sub_attrs = self.attr_sets[msg_format.attr_set]
831                pathname += f"({value})"
832            else:
833                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
834            offset += 4
835            subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs,
836                                               offset, target, search_attrs)
837            if subpath is None:
838                return None
839            return '.' + pathname + subpath
840
841        return None
842
843    def _decode_extack(self, request, op, extack, vals):
844        if 'bad-attr-offs' not in extack:
845            return
846
847        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
848        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
849        search_attrs = SpaceAttrs(op.attr_set, vals)
850        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
851                                        extack['bad-attr-offs'], search_attrs)
852        if path:
853            del extack['bad-attr-offs']
854            extack['bad-attr'] = path
855
856    def _struct_size(self, name):
857        if name:
858            members = self.consts[name].members
859            size = 0
860            for m in members:
861                if m.type in ['pad', 'binary']:
862                    if m.struct:
863                        size += self._struct_size(m.struct)
864                    else:
865                        size += m.len
866                else:
867                    format = NlAttr.get_format(m.type, m.byte_order)
868                    size += format.size
869            return size
870        else:
871            return 0
872
873    def _decode_struct(self, data, name):
874        members = self.consts[name].members
875        attrs = dict()
876        offset = 0
877        for m in members:
878            value = None
879            if m.type == 'pad':
880                offset += m.len
881            elif m.type == 'binary':
882                if m.struct:
883                    len = self._struct_size(m.struct)
884                    value = self._decode_struct(data[offset : offset + len],
885                                                m.struct)
886                    offset += len
887                else:
888                    value = data[offset : offset + m.len]
889                    offset += m.len
890            else:
891                format = NlAttr.get_format(m.type, m.byte_order)
892                [ value ] = format.unpack_from(data, offset)
893                offset += format.size
894            if value is not None:
895                if m.enum:
896                    value = self._decode_enum(value, m)
897                elif m.display_hint:
898                    value = self._formatted_string(value, m.display_hint)
899                attrs[m.name] = value
900        return attrs
901
902    def _encode_struct(self, name, vals):
903        members = self.consts[name].members
904        attr_payload = b''
905        for m in members:
906            value = vals.pop(m.name) if m.name in vals else None
907            if m.type == 'pad':
908                attr_payload += bytearray(m.len)
909            elif m.type == 'binary':
910                if m.struct:
911                    if value is None:
912                        value = dict()
913                    attr_payload += self._encode_struct(m.struct, value)
914                else:
915                    if value is None:
916                        attr_payload += bytearray(m.len)
917                    else:
918                        attr_payload += bytes.fromhex(value)
919            else:
920                if value is None:
921                    value = 0
922                format = NlAttr.get_format(m.type, m.byte_order)
923                attr_payload += format.pack(value)
924        return attr_payload
925
926    def _formatted_string(self, raw, display_hint):
927        if display_hint == 'mac':
928            formatted = ':'.join('%02x' % b for b in raw)
929        elif display_hint == 'hex':
930            if isinstance(raw, int):
931                formatted = hex(raw)
932            else:
933                formatted = bytes.hex(raw, ' ')
934        elif display_hint in [ 'ipv4', 'ipv6' ]:
935            formatted = format(ipaddress.ip_address(raw))
936        elif display_hint == 'uuid':
937            formatted = str(uuid.UUID(bytes=raw))
938        else:
939            formatted = raw
940        return formatted
941
942    def _from_string(self, string, attr_spec):
943        if attr_spec.display_hint in ['ipv4', 'ipv6']:
944            ip = ipaddress.ip_address(string)
945            if attr_spec['type'] == 'binary':
946                raw = ip.packed
947            else:
948                raw = int(ip)
949        else:
950            raise Exception(f"Display hint '{attr_spec.display_hint}' not implemented"
951                            f" when parsing '{attr_spec['name']}'")
952        return raw
953
954    def handle_ntf(self, decoded):
955        msg = dict()
956        if self.include_raw:
957            msg['raw'] = decoded
958        op = self.rsp_by_value[decoded.cmd()]
959        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
960        if op.fixed_header:
961            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
962
963        msg['name'] = op['name']
964        msg['msg'] = attrs
965        self.async_msg_queue.put(msg)
966
967    def check_ntf(self):
968        while True:
969            try:
970                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
971            except BlockingIOError:
972                return
973
974            nms = NlMsgs(reply)
975            self._recv_dbg_print(reply, nms)
976            for nl_msg in nms:
977                if nl_msg.error:
978                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
979                    print(nl_msg)
980                    continue
981                if nl_msg.done:
982                    print("Netlink done while checking for ntf!?")
983                    continue
984
985                decoded = self.nlproto.decode(self, nl_msg, None)
986                if decoded.cmd() not in self.async_msg_ids:
987                    print("Unexpected msg id while checking for ntf", decoded)
988                    continue
989
990                self.handle_ntf(decoded)
991
992    def poll_ntf(self, duration=None):
993        start_time = time.time()
994        selector = selectors.DefaultSelector()
995        selector.register(self.sock, selectors.EVENT_READ)
996
997        while True:
998            try:
999                yield self.async_msg_queue.get_nowait()
1000            except queue.Empty:
1001                if duration is not None:
1002                    timeout = start_time + duration - time.time()
1003                    if timeout <= 0:
1004                        return
1005                else:
1006                    timeout = None
1007                events = selector.select(timeout)
1008                if events:
1009                    self.check_ntf()
1010
1011    def operation_do_attributes(self, name):
1012      """
1013      For a given operation name, find and return a supported
1014      set of attributes (as a dict).
1015      """
1016      op = self.find_operation(name)
1017      if not op:
1018        return None
1019
1020      return op['do']['request']['attributes'].copy()
1021
1022    def _encode_message(self, op, vals, flags, req_seq):
1023        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
1024        for flag in flags or []:
1025            nl_flags |= flag
1026
1027        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
1028        if op.fixed_header:
1029            msg += self._encode_struct(op.fixed_header, vals)
1030        search_attrs = SpaceAttrs(op.attr_set, vals)
1031        for name, value in vals.items():
1032            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
1033        msg = _genl_msg_finalize(msg)
1034        return msg
1035
1036    def _ops(self, ops):
1037        reqs_by_seq = {}
1038        req_seq = random.randint(1024, 65535)
1039        payload = b''
1040        for (method, vals, flags) in ops:
1041            op = self.ops[method]
1042            msg = self._encode_message(op, vals, flags, req_seq)
1043            reqs_by_seq[req_seq] = (op, vals, msg, flags)
1044            payload += msg
1045            req_seq += 1
1046
1047        self.sock.send(payload, 0)
1048
1049        done = False
1050        rsp = []
1051        op_rsp = []
1052        while not done:
1053            reply = self.sock.recv(self._recv_size)
1054            nms = NlMsgs(reply)
1055            self._recv_dbg_print(reply, nms)
1056            for nl_msg in nms:
1057                if nl_msg.nl_seq in reqs_by_seq:
1058                    (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
1059                    if nl_msg.extack:
1060                        nl_msg.annotate_extack(op.attr_set)
1061                        self._decode_extack(req_msg, op, nl_msg.extack, vals)
1062                else:
1063                    op = None
1064                    req_flags = []
1065
1066                if nl_msg.error:
1067                    raise NlError(nl_msg)
1068                if nl_msg.done:
1069                    if nl_msg.extack:
1070                        print("Netlink warning:")
1071                        print(nl_msg)
1072
1073                    if Netlink.NLM_F_DUMP in req_flags:
1074                        rsp.append(op_rsp)
1075                    elif not op_rsp:
1076                        rsp.append(None)
1077                    elif len(op_rsp) == 1:
1078                        rsp.append(op_rsp[0])
1079                    else:
1080                        rsp.append(op_rsp)
1081                    op_rsp = []
1082
1083                    del reqs_by_seq[nl_msg.nl_seq]
1084                    done = len(reqs_by_seq) == 0
1085                    break
1086
1087                decoded = self.nlproto.decode(self, nl_msg, op)
1088
1089                # Check if this is a reply to our request
1090                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
1091                    if decoded.cmd() in self.async_msg_ids:
1092                        self.handle_ntf(decoded)
1093                        continue
1094                    else:
1095                        print('Unexpected message: ' + repr(decoded))
1096                        continue
1097
1098                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
1099                if op.fixed_header:
1100                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
1101                op_rsp.append(rsp_msg)
1102
1103        return rsp
1104
1105    def _op(self, method, vals, flags=None, dump=False):
1106        req_flags = flags or []
1107        if dump:
1108            req_flags.append(Netlink.NLM_F_DUMP)
1109
1110        ops = [(method, vals, req_flags)]
1111        return self._ops(ops)[0]
1112
1113    def do(self, method, vals, flags=None):
1114        return self._op(method, vals, flags)
1115
1116    def dump(self, method, vals):
1117        return self._op(method, vals, dump=True)
1118
1119    def do_multi(self, ops):
1120        return self._ops(ops)
1121