xref: /linux/tools/net/ynl/pyynl/lib/ynl.py (revision 8be4d31cb8aaeea27bde4b7ddb26e28a89062ebf)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4from enum import Enum
5import functools
6import os
7import random
8import socket
9import struct
10from struct import Struct
11import sys
12import yaml
13import ipaddress
14import uuid
15import queue
16import selectors
17import time
18
19from .nlspec import SpecFamily
20
21#
22# Generic Netlink code which should really be in some library, but I can't quickly find one.
23#
24
25
class Netlink:
    """Numeric constants for the netlink / generic netlink wire protocol.

    The class is only used as a namespace; nothing is instantiated.
    """

    # --- setsockopt() level and option names ---
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # --- control message types ---
    NLMSG_ERROR = 0x2
    NLMSG_DONE = 0x3

    # --- nlmsg_flags ---
    NLM_F_REQUEST = 0x1
    NLM_F_ACK = 0x4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    # request-modifier flags; these reuse the bit values of ROOT / MATCH
    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    # flags seen on error / ack messages (again overlapping bit values)
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # --- attribute header type bits ---
    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # --- generic netlink ---
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl command and attributes
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # --- extended ACK attribute types ---
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # --- attributes inside an NLMSGERR_ATTR_POLICY nest ---
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Policy attribute types; the functional Enum API numbers them from 1
    # in list order (flag=1 ... uint=17).
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                 's8', 's16', 's32', 's64',
                                 'binary', 'string', 'nul-string',
                                 'nested', 'nested-array',
                                 'bitfield32', 'sint', 'uint'])
102
class NlError(Exception):
    """Exception carrying a netlink error reply.

    ``nl_msg`` is the reply message. ``error`` is ``-nl_msg.error``, i.e. the
    positive errno that is rendered with os.strerror().
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        reason = os.strerror(self.error)
        return f"Netlink error: {reason}\n{self.nl_msg}"
110
111
class ConfigError(Exception):
    """Raised when a YnlFamily configuration value is unusable (e.g. recv_size)."""
114
115
class NlAttr:
    """One netlink attribute (TLV) located at ``offset`` inside ``raw``.

    Attributes:
        type: attribute type with NLA_F_NESTED / NLA_F_NET_BYTEORDER masked off.
        is_nest: non-zero when NLA_F_NESTED was set in the header.
        payload_len: the header's nla_len (header + payload, no padding).
        full_len: nla_len rounded up to a multiple of 4 bytes, i.e. the
            distance to the next attribute.
        raw: the payload bytes (header and padding excluded).
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the precompiled Struct for scalar ``attr_type``.

        A falsy ``byte_order`` selects native order, "big-endian" selects
        big-endian, and any other truthy value selects little-endian.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" \
                else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the payload as the fixed-size scalar ``attr_type``."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a variable-width sint/uint: a 4- or 8-byte payload.

        Raises:
            Exception: the payload is neither 4 nor 8 bytes long.
        """
        if len(self.raw) not in (4, 8):
            raise Exception(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        # 'sint' -> 's32'/'s64', 'uint' -> 'u32'/'u64'
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as an ASCII string.

        A trailing NUL terminator is dropped if present. Strings sent without
        one are returned whole instead of losing their last character.
        """
        data = self.raw
        if data.endswith(b'\x00'):
            data = data[:-1]
        return data.decode('ascii')

    def as_bin(self):
        """Return the payload bytes unchanged."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as a packed array of native-order ``type`` scalars."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
168
169
class NlAttrs:
    """The sequence of netlink attributes in ``msg`` starting at ``offset``.

    Raises:
        Exception: an attribute header declares a length shorter than the
            4-byte header itself (malformed data; previously a length of 0
            made the parser loop forever).
    """
    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            # nla_len covers the 4-byte header; anything smaller is corrupt
            # and nla_len == 0 would never advance the offset.
            if attr.payload_len < 4:
                raise Exception(f"Malformed netlink attribute at offset {offset}: "
                                f"nla_len {attr.payload_len}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)
189
190
class NlMsg:
    """A single netlink message parsed out of a receive buffer.

    Args:
        msg: buffer holding one or more netlink messages.
        offset: byte offset of this message's 16-byte nlmsghdr in ``msg``.
        attr_space: optional attribute set used to replace the numeric
            extack 'miss-type' with an attribute name.

    Attributes:
        nl_len, nl_type, nl_flags, nl_seq, nl_portid: nlmsghdr fields.
        raw: bytes after the 16-byte header, up to nl_len.
        error: error code from NLMSG_ERROR / NLMSG_DONE (0 otherwise);
            NlError negates it to obtain the errno.
        done: 1 for NLMSG_ERROR and NLMSG_DONE, else 0.
        extack: dict of decoded extended-ack TLVs, or None when absent.
    """
    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # error code (4 bytes) + echoed request nlmsghdr (16 bytes);
            # presumably no request payload follows because the sockets in
            # this file enable NETLINK_CAP_ACK.
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            # Tolerate a DONE message without the error word instead of
            # failing in struct.unpack on a short payload.
            if len(self.raw) >= 4:
                self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = {}
            for attr in NlAttrs(self.raw[extack_off:]):
                if attr.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = attr.as_strz()
                elif attr.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = attr.as_scalar('u32')
                elif attr.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = attr.as_scalar('u32')
                elif attr.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = attr.as_scalar('u32')
                elif attr.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(attr.raw)
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(attr)

            if attr_space:
                self.annotate_extack(attr_space)

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY nest into a dict.

        A policy type outside Netlink.AttrType is kept as its integer value,
        so that an unrecognized type does not abort decoding of the error.
        """
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_val = attr.as_scalar('u32')
                try:
                    policy['type'] = Netlink.AttrType(type_val).name
                except ValueError:
                    policy['type'] = type_val
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def annotate_extack(self, attr_space):
        """ Make extack more human friendly with attribute information """

        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
            miss_type = self.extack['miss-type']
            if miss_type in attr_space.attrs_by_val:
                spec = attr_space.attrs_by_val[miss_type]
                self.extack['miss-type'] = spec['name']
                if 'doc' in spec:
                    self.extack['miss-type-doc'] = spec['doc']

    def cmd(self):
        """Return the message type (nlmsg_type)."""
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg
282
283
class NlMsgs:
    """All netlink messages contained in one receive buffer ``data``.

    Messages are walked the way NLMSG_NEXT() does: each one starts at the
    previous offset plus nlmsg_len rounded up to 4 bytes.

    Raises:
        Exception: a header declares nlmsg_len smaller than the 16-byte
            nlmsghdr (corrupt data; a length of 0 would loop forever).
    """
    def __init__(self, data):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset)
            if msg.nl_len < 16:
                raise Exception(f"Malformed netlink message at offset {offset}: "
                                f"nlmsg_len {msg.nl_len}")
            # NLMSG_ALIGN(): messages in a buffer are 4-byte aligned
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
296
297
# Cache of generic netlink families keyed by family name. Each value is a
# dict with 'id' and 'name', plus 'maxattr' and 'mcast' (group name -> id)
# when the kernel reported them. It stays None until _genl_load_families()
# has run.
genl_family_name_to_id = None
299
300
301def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
302    # we prepend length in _genl_msg_finalize()
303    if seq is None:
304        seq = random.randint(1, 1024)
305    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
306    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
307    return nlmsg + genlmsg
308
309
310def _genl_msg_finalize(msg):
311    return struct.pack("I", len(msg) + 4) + msg
312
313
def _genl_parse_family(nl_msg):
    """Decode one CTRL_CMD_GETFAMILY reply into a family info dict.

    Possible keys: 'id', 'name', 'maxattr' and 'mcast' (group name -> id).
    Only the keys whose attributes were present are set.
    """
    gm = GenlMsg(nl_msg)
    fam = {}
    for attr in NlAttrs(gm.raw):
        if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
            fam['id'] = attr.as_scalar('u16')
        elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
            fam['name'] = attr.as_strz()
        elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
            fam['maxattr'] = attr.as_scalar('u32')
        elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
            fam['mcast'] = {}
            for entry in NlAttrs(attr.raw):
                mcast_name = None
                mcast_id = None
                for entry_attr in NlAttrs(entry.raw):
                    if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                        mcast_name = entry_attr.as_strz()
                    elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                        mcast_id = entry_attr.as_scalar('u32')
                if mcast_name and mcast_id is not None:
                    fam['mcast'][mcast_name] = mcast_id
    return fam


def _genl_load_families():
    """Dump all generic netlink families into genl_family_name_to_id.

    The cache is reset to an empty dict once the request is sent. Replies
    are read until NLMSG_DONE. On a netlink error the error code is printed
    to stderr and loading stops, keeping whatever was parsed so far.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        genl_family_name_to_id = {}

        while True:
            reply = sock.recv(128 * 1024)
            for nl_msg in NlMsgs(reply):
                if nl_msg.error:
                    # Diagnostics belong on stderr; stdout may carry the
                    # caller's decoded output.
                    print("Netlink error:", nl_msg.error, file=sys.stderr)
                    return
                if nl_msg.done:
                    return

                fam = _genl_parse_family(nl_msg)
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam
361
362
class GenlMsg:
    """A generic netlink message: the 4-byte genlmsghdr split off ``nl_msg.raw``.

    ``raw`` holds the bytes after the genlmsghdr. ``raw_attrs`` is not set
    here; NetlinkProtocol.decode() assigns it, so __repr__ tolerates its
    absence.
    """
    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the generic netlink command (genlmsghdr.cmd)."""
        return self.genl_cmd

    def __repr__(self):
        # The nlmsghdr repr has no trailing newline; start the genl line fresh.
        msg = repr(self.nl) + '\n'
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs exists only after the message went through decode()
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg
378
379
class NetlinkProtocol:
    """Message framing for a classic (non-generic) netlink family.

    ``family_name`` is the spec's family name and ``proto_num`` the netlink
    protocol number the socket is opened with. Subclasses override the hooks
    when the protocol adds its own header.
    """
    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack the nlmsghdr fields after nlmsg_len (portid is always 0).

        A random sequence number in [1, 1024] is used when ``seq`` is None.
        """
        sequence = random.randint(1, 1024) if seq is None else seq
        return struct.pack("HHII", nl_type, nl_flags, sequence, 0)

    def message(self, flags, command, version, seq=None):
        """Start a request; the command is the nlmsg type and ``version`` is unused."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        """Protocol-specific wrapping hook; raw netlink returns the message itself."""
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap ``nl_msg`` and attach its attributes as ``raw_attrs``.

        When ``op`` is None it is looked up by the message's command. The
        op's fixed header (if any) is skipped before attribute parsing.
        """
        wrapped = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[wrapped.cmd()]
        header_len = ynl._struct_size(op.fixed_header)
        wrapped.raw_attrs = NlAttrs(wrapped.raw, header_len)
        return wrapped

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the multicast group id declared in the spec's ``mcast_groups``."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        """Size in bytes of the message header (the 16-byte nlmsghdr)."""
        return 16
412
413
class GenlProtocol(NetlinkProtocol):
    """Generic netlink framing: requests carry the family id plus a genlmsghdr.

    The family id and multicast groups come from the module-level family
    cache, which is loaded from the kernel on first use. An unknown
    ``family_name`` raises KeyError.
    """
    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        # Only read here; _genl_load_families() rebinds the global itself.
        if genl_family_name_to_id is None:
            _genl_load_families()

        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']

    def message(self, flags, command, version, seq=None):
        """Build nlmsghdr (type = family id) followed by the genlmsghdr."""
        genl_hdr = struct.pack("BBH", command, version, 0)
        return self._message(self.family_id, flags, seq) + genl_hdr

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the kernel-assigned id of group ``mcast_name`` (``mcast_groups`` unused)."""
        kernel_groups = self.genl_family['mcast']
        if mcast_name in kernel_groups:
            return kernel_groups[mcast_name]
        raise Exception(f'Multicast group "{mcast_name}" not present in the family')

    def msghdr_size(self):
        """nlmsghdr size plus the 4-byte genlmsghdr."""
        return super().msghdr_size() + 4
440
441
class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes searched innermost first.

    Used to resolve sub-message selectors, which may name an attribute of
    the current set or of any enclosing one.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        self.scopes = [self.SpecValuesPair(attr_space, attrs)]
        if outer:
            self.scopes.extend(outer.scopes)

    def lookup(self, name):
        """Return the value of ``name`` from the innermost scope whose spec defines it.

        Raises:
            Exception: the defining scope holds no value, or no scope defines it.
        """
        for spec, values in self.scopes:
            if name not in spec:
                continue
            if name in values:
                return values[name]
            spec_name = spec.yaml['name']
            raise Exception(
                f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")
459
460
461#
462# YNL implementation details.
463#
464
465
466class YnlFamily(SpecFamily):
467    def __init__(self, def_path, schema=None, process_unknown=False,
468                 recv_size=0):
469        super().__init__(def_path, schema)
470
471        self.include_raw = False
472        self.process_unknown = process_unknown
473
474        try:
475            if self.proto == "netlink-raw":
476                self.nlproto = NetlinkProtocol(self.yaml['name'],
477                                               self.yaml['protonum'])
478            else:
479                self.nlproto = GenlProtocol(self.yaml['name'])
480        except KeyError:
481            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
482
483        self._recv_dbg = False
484        # Note that netlink will use conservative (min) message size for
485        # the first dump recv() on the socket, our setting will only matter
486        # from the second recv() on.
487        self._recv_size = recv_size if recv_size else 131072
488        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
489        # for a message, so smaller receive sizes will lead to truncation.
490        # Note that the min size for other families may be larger than 4k!
491        if self._recv_size < 4000:
492            raise ConfigError()
493
494        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
495        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
496        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
497        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
498
499        self.async_msg_ids = set()
500        self.async_msg_queue = queue.Queue()
501
502        for msg in self.msgs.values():
503            if msg.is_async:
504                self.async_msg_ids.add(msg.rsp_value)
505
506        for op_name, op in self.ops.items():
507            bound_f = functools.partial(self._op, op_name)
508            setattr(self, op.ident_name, bound_f)
509
510
511    def ntf_subscribe(self, mcast_name):
512        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
513        self.sock.bind((0, 0))
514        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
515                             mcast_id)
516
517    def set_recv_dbg(self, enabled):
518        self._recv_dbg = enabled
519
520    def _recv_dbg_print(self, reply, nl_msgs):
521        if not self._recv_dbg:
522            return
523        print("Recv: read", len(reply), "bytes,",
524              len(nl_msgs.msgs), "messages", file=sys.stderr)
525        for nl_msg in nl_msgs:
526            print("  ", nl_msg, file=sys.stderr)
527
528    def _encode_enum(self, attr_spec, value):
529        enum = self.consts[attr_spec['enum']]
530        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
531            scalar = 0
532            if isinstance(value, str):
533                value = [value]
534            for single_value in value:
535                scalar += enum.entries[single_value].user_value(as_flags = True)
536            return scalar
537        else:
538            return enum.entries[value].user_value()
539
540    def _get_scalar(self, attr_spec, value):
541        try:
542            return int(value)
543        except (ValueError, TypeError) as e:
544            if 'enum' in attr_spec:
545                return self._encode_enum(attr_spec, value)
546            if attr_spec.display_hint:
547                return self._from_string(value, attr_spec)
548            raise e
549
550    def _add_attr(self, space, name, value, search_attrs):
551        try:
552            attr = self.attr_sets[space][name]
553        except KeyError:
554            raise Exception(f"Space '{space}' has no attribute '{name}'")
555        nl_type = attr.value
556
557        if attr.is_multi and isinstance(value, list):
558            attr_payload = b''
559            for subvalue in value:
560                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
561            return attr_payload
562
563        if attr["type"] == 'nest':
564            nl_type |= Netlink.NLA_F_NESTED
565            attr_payload = b''
566            sub_space = attr['nested-attributes']
567            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
568            for subname, subvalue in value.items():
569                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
570        elif attr["type"] == 'flag':
571            if not value:
572                # If value is absent or false then skip attribute creation.
573                return b''
574            attr_payload = b''
575        elif attr["type"] == 'string':
576            attr_payload = str(value).encode('ascii') + b'\x00'
577        elif attr["type"] == 'binary':
578            if value is None:
579                attr_payload = b''
580            elif isinstance(value, bytes):
581                attr_payload = value
582            elif isinstance(value, str):
583                if attr.display_hint:
584                    attr_payload = self._from_string(value, attr)
585                else:
586                    attr_payload = bytes.fromhex(value)
587            elif isinstance(value, dict) and attr.struct_name:
588                attr_payload = self._encode_struct(attr.struct_name, value)
589            elif isinstance(value, list) and attr.sub_type in NlAttr.type_formats:
590                format = NlAttr.get_format(attr.sub_type)
591                attr_payload = b''.join([format.pack(x) for x in value])
592            else:
593                raise Exception(f'Unknown type for binary attribute, value: {value}')
594        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
595            scalar = self._get_scalar(attr, value)
596            if attr.is_auto_scalar:
597                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
598            else:
599                attr_type = attr["type"]
600            format = NlAttr.get_format(attr_type, attr.byte_order)
601            attr_payload = format.pack(scalar)
602        elif attr['type'] in "bitfield32":
603            scalar_value = self._get_scalar(attr, value["value"])
604            scalar_selector = self._get_scalar(attr, value["selector"])
605            attr_payload = struct.pack("II", scalar_value, scalar_selector)
606        elif attr['type'] == 'sub-message':
607            msg_format, _ = self._resolve_selector(attr, search_attrs)
608            attr_payload = b''
609            if msg_format.fixed_header:
610                attr_payload += self._encode_struct(msg_format.fixed_header, value)
611            if msg_format.attr_set:
612                if msg_format.attr_set in self.attr_sets:
613                    nl_type |= Netlink.NLA_F_NESTED
614                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
615                    for subname, subvalue in value.items():
616                        attr_payload += self._add_attr(msg_format.attr_set,
617                                                       subname, subvalue, sub_attrs)
618                else:
619                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
620        else:
621            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
622
623        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
624        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
625
626    def _get_enum_or_unknown(self, enum, raw):
627        try:
628            name = enum.entries_by_val[raw].name
629        except KeyError as error:
630            if self.process_unknown:
631                name = f"Unknown({raw})"
632            else:
633                raise error
634        return name
635
636    def _decode_enum(self, raw, attr_spec):
637        enum = self.consts[attr_spec['enum']]
638        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
639            i = 0
640            value = set()
641            while raw:
642                if raw & 1:
643                    value.add(self._get_enum_or_unknown(enum, i))
644                raw >>= 1
645                i += 1
646        else:
647            value = self._get_enum_or_unknown(enum, raw)
648        return value
649
650    def _decode_binary(self, attr, attr_spec):
651        if attr_spec.struct_name:
652            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
653        elif attr_spec.sub_type:
654            decoded = attr.as_c_array(attr_spec.sub_type)
655            if 'enum' in attr_spec:
656                decoded = [ self._decode_enum(x, attr_spec) for x in decoded ]
657            elif attr_spec.display_hint:
658                decoded = [ self._formatted_string(x, attr_spec.display_hint)
659                            for x in decoded ]
660        else:
661            decoded = attr.as_bin()
662            if attr_spec.display_hint:
663                decoded = self._formatted_string(decoded, attr_spec.display_hint)
664        return decoded
665
666    def _decode_array_attr(self, attr, attr_spec):
667        decoded = []
668        offset = 0
669        while offset < len(attr.raw):
670            item = NlAttr(attr.raw, offset)
671            offset += item.full_len
672
673            if attr_spec["sub-type"] == 'nest':
674                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
675                decoded.append({ item.type: subattrs })
676            elif attr_spec["sub-type"] == 'binary':
677                subattr = item.as_bin()
678                if attr_spec.display_hint:
679                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
680                decoded.append(subattr)
681            elif attr_spec["sub-type"] in NlAttr.type_formats:
682                subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
683                if 'enum' in attr_spec:
684                    subattr = self._decode_enum(subattr, attr_spec)
685                elif attr_spec.display_hint:
686                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
687                decoded.append(subattr)
688            else:
689                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
690        return decoded
691
692    def _decode_nest_type_value(self, attr, attr_spec):
693        decoded = {}
694        value = attr
695        for name in attr_spec['type-value']:
696            value = NlAttr(value.raw, 0)
697            decoded[name] = value.type
698        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
699        decoded.update(subattrs)
700        return decoded
701
702    def _decode_unknown(self, attr):
703        if attr.is_nest:
704            return self._decode(NlAttrs(attr.raw), None)
705        else:
706            return attr.as_bin()
707
708    def _rsp_add(self, rsp, name, is_multi, decoded):
709        if is_multi == None:
710            if name in rsp and type(rsp[name]) is not list:
711                rsp[name] = [rsp[name]]
712                is_multi = True
713            else:
714                is_multi = False
715
716        if not is_multi:
717            rsp[name] = decoded
718        elif name in rsp:
719            rsp[name].append(decoded)
720        else:
721            rsp[name] = [decoded]
722
723    def _resolve_selector(self, attr_spec, search_attrs):
724        sub_msg = attr_spec.sub_message
725        if sub_msg not in self.sub_msgs:
726            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
727        sub_msg_spec = self.sub_msgs[sub_msg]
728
729        selector = attr_spec.selector
730        value = search_attrs.lookup(selector)
731        if value not in sub_msg_spec.formats:
732            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
733
734        spec = sub_msg_spec.formats[value]
735        return spec, value
736
737    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
738        msg_format, _ = self._resolve_selector(attr_spec, search_attrs)
739        decoded = {}
740        offset = 0
741        if msg_format.fixed_header:
742            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header));
743            offset = self._struct_size(msg_format.fixed_header)
744        if msg_format.attr_set:
745            if msg_format.attr_set in self.attr_sets:
746                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
747                decoded.update(subdict)
748            else:
749                raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")
750        return decoded
751
752    def _decode(self, attrs, space, outer_attrs = None):
753        rsp = dict()
754        if space:
755            attr_space = self.attr_sets[space]
756            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
757
758        for attr in attrs:
759            try:
760                attr_spec = attr_space.attrs_by_val[attr.type]
761            except (KeyError, UnboundLocalError):
762                if not self.process_unknown:
763                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
764                attr_name = f"UnknownAttr({attr.type})"
765                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
766                continue
767
768            try:
769                if attr_spec["type"] == 'nest':
770                    subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
771                    decoded = subdict
772                elif attr_spec["type"] == 'string':
773                    decoded = attr.as_strz()
774                elif attr_spec["type"] == 'binary':
775                    decoded = self._decode_binary(attr, attr_spec)
776                elif attr_spec["type"] == 'flag':
777                    decoded = True
778                elif attr_spec.is_auto_scalar:
779                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
780                    if 'enum' in attr_spec:
781                        decoded = self._decode_enum(decoded, attr_spec)
782                elif attr_spec["type"] in NlAttr.type_formats:
783                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
784                    if 'enum' in attr_spec:
785                        decoded = self._decode_enum(decoded, attr_spec)
786                    elif attr_spec.display_hint:
787                        decoded = self._formatted_string(decoded, attr_spec.display_hint)
788                elif attr_spec["type"] == 'indexed-array':
789                    decoded = self._decode_array_attr(attr, attr_spec)
790                elif attr_spec["type"] == 'bitfield32':
791                    value, selector = struct.unpack("II", attr.raw)
792                    if 'enum' in attr_spec:
793                        value = self._decode_enum(value, attr_spec)
794                        selector = self._decode_enum(selector, attr_spec)
795                    decoded = {"value": value, "selector": selector}
796                elif attr_spec["type"] == 'sub-message':
797                    decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
798                elif attr_spec["type"] == 'nest-type-value':
799                    decoded = self._decode_nest_type_value(attr, attr_spec)
800                else:
801                    if not self.process_unknown:
802                        raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
803                    decoded = self._decode_unknown(attr)
804
805                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
806            except:
807                print(f"Error decoding '{attr_spec.name}' from '{space}'")
808                raise
809
810        return rsp
811
    def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs):
        """
        Map the byte offset reported in an extack back to an attribute path.

        attrs:        attributes (of attr_set) that start at byte `offset`
                      within the original request message.
        attr_set:     attribute-set spec used to name the attributes.
        offset:       byte offset of the first attribute in attrs.
        target:       byte offset the kernel flagged as the bad attribute.
        search_attrs: SpaceAttrs used to resolve sub-message selectors.

        Returns a dotted path such as '.outer.inner'. Sub-message levels are
        rendered as 'name(selector-value)'. Returns None when no attribute
        starts exactly at target.

        Raises Exception if an attribute is missing from attr_set, if a
        sub-message selector cannot be resolved, or if target lies inside
        an attribute type that cannot be descended into.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            # Skipped attributes never move offset past target, so this only
            # triggers when target pointed before the first attribute, e.g.
            # into the enclosing attribute's header skipped below.
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            # Target lies beyond this attribute entirely: step over it.
            # (full_len presumably includes header and alignment padding.)
            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue

            # Target is inside this attribute: work out which attribute set
            # its payload uses so we can recurse into it.
            pathname = attr_spec.name
            if attr_spec['type'] == 'nest':
                sub_attrs = self.attr_sets[attr_spec['nested-attributes']]
                # Narrow the selector search scope to this nest's values
                # (presumably the request values for this attribute).
                search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name']))
            elif attr_spec['type'] == 'sub-message':
                msg_format, value = self._resolve_selector(attr_spec, search_attrs)
                if msg_format is None:
                    raise Exception(f"Can't resolve sub-message of {attr_spec['name']} for extack")
                sub_attrs = self.attr_sets[msg_format.attr_set]
                pathname += f"({value})"
            else:
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Skip this attribute's 4-byte header; its payload starts here.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs,
                                               offset, target, search_attrs)
            # Target is inside this attribute, so no later sibling can match.
            if subpath is None:
                return None
            return '.' + pathname + subpath

        return None
847
848    def _decode_extack(self, request, op, extack, vals):
849        if 'bad-attr-offs' not in extack:
850            return
851
852        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
853        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
854        search_attrs = SpaceAttrs(op.attr_set, vals)
855        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
856                                        extack['bad-attr-offs'], search_attrs)
857        if path:
858            del extack['bad-attr-offs']
859            extack['bad-attr'] = path
860
861    def _struct_size(self, name):
862        if name:
863            members = self.consts[name].members
864            size = 0
865            for m in members:
866                if m.type in ['pad', 'binary']:
867                    if m.struct:
868                        size += self._struct_size(m.struct)
869                    else:
870                        size += m.len
871                else:
872                    format = NlAttr.get_format(m.type, m.byte_order)
873                    size += format.size
874            return size
875        else:
876            return 0
877
878    def _decode_struct(self, data, name):
879        members = self.consts[name].members
880        attrs = dict()
881        offset = 0
882        for m in members:
883            value = None
884            if m.type == 'pad':
885                offset += m.len
886            elif m.type == 'binary':
887                if m.struct:
888                    len = self._struct_size(m.struct)
889                    value = self._decode_struct(data[offset : offset + len],
890                                                m.struct)
891                    offset += len
892                else:
893                    value = data[offset : offset + m.len]
894                    offset += m.len
895            else:
896                format = NlAttr.get_format(m.type, m.byte_order)
897                [ value ] = format.unpack_from(data, offset)
898                offset += format.size
899            if value is not None:
900                if m.enum:
901                    value = self._decode_enum(value, m)
902                elif m.display_hint:
903                    value = self._formatted_string(value, m.display_hint)
904                attrs[m.name] = value
905        return attrs
906
907    def _encode_struct(self, name, vals):
908        members = self.consts[name].members
909        attr_payload = b''
910        for m in members:
911            value = vals.pop(m.name) if m.name in vals else None
912            if m.type == 'pad':
913                attr_payload += bytearray(m.len)
914            elif m.type == 'binary':
915                if m.struct:
916                    if value is None:
917                        value = dict()
918                    attr_payload += self._encode_struct(m.struct, value)
919                else:
920                    if value is None:
921                        attr_payload += bytearray(m.len)
922                    else:
923                        attr_payload += bytes.fromhex(value)
924            else:
925                if value is None:
926                    value = 0
927                format = NlAttr.get_format(m.type, m.byte_order)
928                attr_payload += format.pack(value)
929        return attr_payload
930
931    def _formatted_string(self, raw, display_hint):
932        if display_hint == 'mac':
933            formatted = ':'.join('%02x' % b for b in raw)
934        elif display_hint == 'hex':
935            if isinstance(raw, int):
936                formatted = hex(raw)
937            else:
938                formatted = bytes.hex(raw, ' ')
939        elif display_hint in [ 'ipv4', 'ipv6' ]:
940            formatted = format(ipaddress.ip_address(raw))
941        elif display_hint == 'uuid':
942            formatted = str(uuid.UUID(bytes=raw))
943        else:
944            formatted = raw
945        return formatted
946
947    def _from_string(self, string, attr_spec):
948        if attr_spec.display_hint in ['ipv4', 'ipv6']:
949            ip = ipaddress.ip_address(string)
950            if attr_spec['type'] == 'binary':
951                raw = ip.packed
952            else:
953                raw = int(ip)
954        else:
955            raise Exception(f"Display hint '{attr_spec.display_hint}' not implemented"
956                            f" when parsing '{attr_spec['name']}'")
957        return raw
958
959    def handle_ntf(self, decoded):
960        msg = dict()
961        if self.include_raw:
962            msg['raw'] = decoded
963        op = self.rsp_by_value[decoded.cmd()]
964        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
965        if op.fixed_header:
966            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
967
968        msg['name'] = op['name']
969        msg['msg'] = attrs
970        self.async_msg_queue.put(msg)
971
972    def check_ntf(self):
973        while True:
974            try:
975                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
976            except BlockingIOError:
977                return
978
979            nms = NlMsgs(reply)
980            self._recv_dbg_print(reply, nms)
981            for nl_msg in nms:
982                if nl_msg.error:
983                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
984                    print(nl_msg)
985                    continue
986                if nl_msg.done:
987                    print("Netlink done while checking for ntf!?")
988                    continue
989
990                decoded = self.nlproto.decode(self, nl_msg, None)
991                if decoded.cmd() not in self.async_msg_ids:
992                    print("Unexpected msg id while checking for ntf", decoded)
993                    continue
994
995                self.handle_ntf(decoded)
996
997    def poll_ntf(self, duration=None):
998        start_time = time.time()
999        selector = selectors.DefaultSelector()
1000        selector.register(self.sock, selectors.EVENT_READ)
1001
1002        while True:
1003            try:
1004                yield self.async_msg_queue.get_nowait()
1005            except queue.Empty:
1006                if duration is not None:
1007                    timeout = start_time + duration - time.time()
1008                    if timeout <= 0:
1009                        return
1010                else:
1011                    timeout = None
1012                events = selector.select(timeout)
1013                if events:
1014                    self.check_ntf()
1015
1016    def operation_do_attributes(self, name):
1017      """
1018      For a given operation name, find and return a supported
1019      set of attributes (as a dict).
1020      """
1021      op = self.find_operation(name)
1022      if not op:
1023        return None
1024
1025      return op['do']['request']['attributes'].copy()
1026
1027    def _encode_message(self, op, vals, flags, req_seq):
1028        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
1029        for flag in flags or []:
1030            nl_flags |= flag
1031
1032        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
1033        if op.fixed_header:
1034            msg += self._encode_struct(op.fixed_header, vals)
1035        search_attrs = SpaceAttrs(op.attr_set, vals)
1036        for name, value in vals.items():
1037            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
1038        msg = _genl_msg_finalize(msg)
1039        return msg
1040
1041    def _ops(self, ops):
1042        reqs_by_seq = {}
1043        req_seq = random.randint(1024, 65535)
1044        payload = b''
1045        for (method, vals, flags) in ops:
1046            op = self.ops[method]
1047            msg = self._encode_message(op, vals, flags, req_seq)
1048            reqs_by_seq[req_seq] = (op, vals, msg, flags)
1049            payload += msg
1050            req_seq += 1
1051
1052        self.sock.send(payload, 0)
1053
1054        done = False
1055        rsp = []
1056        op_rsp = []
1057        while not done:
1058            reply = self.sock.recv(self._recv_size)
1059            nms = NlMsgs(reply)
1060            self._recv_dbg_print(reply, nms)
1061            for nl_msg in nms:
1062                if nl_msg.nl_seq in reqs_by_seq:
1063                    (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
1064                    if nl_msg.extack:
1065                        nl_msg.annotate_extack(op.attr_set)
1066                        self._decode_extack(req_msg, op, nl_msg.extack, vals)
1067                else:
1068                    op = None
1069                    req_flags = []
1070
1071                if nl_msg.error:
1072                    raise NlError(nl_msg)
1073                if nl_msg.done:
1074                    if nl_msg.extack:
1075                        print("Netlink warning:")
1076                        print(nl_msg)
1077
1078                    if Netlink.NLM_F_DUMP in req_flags:
1079                        rsp.append(op_rsp)
1080                    elif not op_rsp:
1081                        rsp.append(None)
1082                    elif len(op_rsp) == 1:
1083                        rsp.append(op_rsp[0])
1084                    else:
1085                        rsp.append(op_rsp)
1086                    op_rsp = []
1087
1088                    del reqs_by_seq[nl_msg.nl_seq]
1089                    done = len(reqs_by_seq) == 0
1090                    break
1091
1092                decoded = self.nlproto.decode(self, nl_msg, op)
1093
1094                # Check if this is a reply to our request
1095                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
1096                    if decoded.cmd() in self.async_msg_ids:
1097                        self.handle_ntf(decoded)
1098                        continue
1099                    else:
1100                        print('Unexpected message: ' + repr(decoded))
1101                        continue
1102
1103                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
1104                if op.fixed_header:
1105                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
1106                op_rsp.append(rsp_msg)
1107
1108        return rsp
1109
1110    def _op(self, method, vals, flags=None, dump=False):
1111        req_flags = flags or []
1112        if dump:
1113            req_flags.append(Netlink.NLM_F_DUMP)
1114
1115        ops = [(method, vals, req_flags)]
1116        return self._ops(ops)[0]
1117
1118    def do(self, method, vals, flags=None):
1119        return self._op(method, vals, flags)
1120
1121    def dump(self, method, vals):
1122        return self._op(method, vals, dump=True)
1123
1124    def do_multi(self, ops):
1125        return self._ops(ops)
1126