xref: /linux/tools/net/ynl/pyynl/lib/ynl.py (revision 5c8013ae2e86ec36b07500ba4cacb14ab4d6f728)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4from enum import Enum
5import functools
6import os
7import random
8import socket
9import struct
10from struct import Struct
11import sys
12import yaml
13import ipaddress
14import uuid
15import queue
16import selectors
17import time
18
19from .nlspec import SpecFamily
20
21#
22# Generic Netlink code which should really be in some library, but I can't quickly find one.
23#
24
25
class Netlink:
    """Numeric constants for netlink sockets, messages and generic netlink."""

    # ---- socket level and socket options ---------------------------------
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # ---- nlmsghdr.nlmsg_type control values ------------------------------
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # ---- nlmsghdr.nlmsg_flags --------------------------------------------
    NLM_F_REQUEST = 0x1
    NLM_F_ACK = 0x4

    # NOTE: the three groups below reuse the same bit values.
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    # ---- nlattr.nla_type flag bits ---------------------------------------
    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000
    # Bits masked off nla_type to obtain the plain attribute number.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # ---- generic netlink -------------------------------------------------
    NETLINK_GENERIC = 16
    GENL_ID_CTRL = 0x10

    # nlctrl command and top-level attributes
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    # attributes inside each CTRL_ATTR_MCAST_GROUPS entry
    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # ---- extended ACK TLV types ------------------------------------------
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # ---- attributes nested in NLMSGERR_ATTR_POLICY -----------------------
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Values carried by NL_POLICY_TYPE_ATTR_TYPE; the first name maps to 1.
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                 's8', 's16', 's32', 's64',
                                 'binary', 'string', 'nul-string',
                                 'nested', 'nested-array',
                                 'bitfield32', 'sint', 'uint'], start=1)
102
class NlError(Exception):
    """Raised for a netlink reply carrying a non-zero error code.

    ``nl_msg`` is the reply message; ``error`` is the positive errno
    derived from the (negative) ``nl_msg.error``.
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        reason = os.strerror(self.error)
        return f"Netlink error: {reason}\n{self.nl_msg}"
110
111
class ConfigError(Exception):
    """Raised when a YNL family object is given an unusable configuration."""
114
115
class NlAttr:
    """One netlink attribute (TLV) parsed from ``raw`` at ``offset``.

    ``raw`` on the instance is the attribute payload, without the 4 byte
    header and without trailing alignment padding; ``full_len`` is the
    4-byte-aligned size the attribute occupies in the buffer.
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # nla_len counts the 4 byte header itself, so a smaller value is
        # malformed.  Rejecting it here also stops callers that advance by
        # full_len from looping forever on a zero length.
        if self._len < 4:
            raise Exception(f"Invalid netlink attribute length {self._len} at offset {offset}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        # Despite the name this is nla_len, i.e. it includes the header.
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for ``attr_type``.

        ``byte_order`` of "big-endian" selects big endian, any other truthy
        value little endian, and None/empty the host's native order.
        """
        format = cls.type_formats[attr_type]
        if byte_order:
            return format.big if byte_order == "big-endian" \
                else format.little
        return format.native

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the payload as a single fixed-width integer."""
        format = self.get_format(attr_type, byte_order)
        return format.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a variable-width sint/uint payload (4 or 8 bytes).

        Only the first letter of ``attr_type`` is used, to pick signedness.
        """
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        format = self.get_format(real_type, byte_order)
        return format.unpack(self.raw)[0]

    def as_strz(self):
        """Decode a NUL-terminated ASCII string, dropping the final byte."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes unchanged."""
        return self.raw

    def as_c_array(self, type, byte_order=None):
        """Unpack the payload as a packed array of ``type`` scalars.

        ``byte_order`` is optional and interpreted like in get_format();
        the default keeps the previous native-order behavior.
        """
        format = self.get_format(type, byte_order)
        return [ x[0] for x in format.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
168
169
class NlAttrs:
    """All attributes packed back to back in ``msg``, starting at ``offset``."""

    def __init__(self, msg, offset=0):
        parsed = []
        pos = offset
        end = len(msg)
        while pos < end:
            item = NlAttr(msg, pos)
            parsed.append(item)
            # Step over the attribute including its alignment padding.
            pos += item.full_len
        self.attrs = parsed

    def __iter__(self):
        return iter(self.attrs)

    def __repr__(self):
        # One attribute per line.
        return '\n'.join(repr(item) for item in self.attrs)
189
190
class NlMsg:
    """One netlink message parsed from ``msg`` at ``offset``.

    Sets the 16 byte header fields (nl_len, nl_type, nl_flags, nl_seq,
    nl_portid), ``raw`` (payload after the header), ``error``/``done`` for
    NLMSG_ERROR and NLMSG_DONE messages, and ``extack`` (a dict, or None
    when no extended ACK TLVs were attached).  When ``attr_space`` is given,
    a top-level missing-attribute report is translated to attribute names.
    """
    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = self._error_extack_offset()
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # NLMSG_DONE carries only the int error before any TLVs.
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = self._decode_extack(self.raw[extack_off:])
            if attr_space:
                self.annotate_extack(attr_space)

    def _error_extack_offset(self):
        """Return where extack TLVs start in an NLMSG_ERROR payload.

        The payload is a 4 byte error code followed by the echoed request.
        With NLM_F_CAPPED only the 16 byte request header is echoed;
        otherwise the whole request is, padded to a 4 byte boundary.
        """
        if self.nl_flags & Netlink.NLM_F_CAPPED or len(self.raw) < 8:
            return 20
        echoed_len = struct.unpack_from("I", self.raw, 4)[0]
        return (4 + echoed_len + 3) & ~3

    def _decode_extack(self, raw):
        """Decode extended ACK TLVs into a dict with readable keys."""
        extack = dict()
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NLMSGERR_ATTR_MSG:
                extack['msg'] = attr.as_strz()
            elif attr.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                extack['miss-type'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                extack['miss-nest'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NLMSGERR_ATTR_OFFS:
                extack['bad-attr-offs'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NLMSGERR_ATTR_POLICY:
                extack['policy'] = self._decode_policy(attr.raw)
            else:
                # Keep TLVs we do not understand as raw NlAttr objects.
                extack.setdefault('unknown', []).append(attr)
        return extack

    def _decode_policy(self, raw):
        """Decode a NLMSGERR_ATTR_POLICY nest into a dict of policy limits."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                policy_type = attr.as_scalar('u32')
                policy['type'] = Netlink.AttrType(policy_type).name
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def annotate_extack(self, attr_space):
        """ Make extack more human friendly with attribute information """

        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
            miss_type = self.extack['miss-type']
            if miss_type in attr_space.attrs_by_val:
                spec = attr_space.attrs_by_val[miss_type]
                self.extack['miss-type'] = spec['name']
                if 'doc' in spec:
                    self.extack['miss-type-doc'] = spec['doc']

    def cmd(self):
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg
282
283
class NlMsgs:
    """All netlink messages contained in one receive buffer ``data``."""
    def __init__(self, data):
        self.msgs = []

        offset = 0
        while offset < len(data):
            # Peek at nlmsg_len first: a value shorter than the 16 byte
            # header is malformed and would never advance ``offset``.
            nl_len = struct.unpack_from("I", data, offset)[0]
            if nl_len < 16:
                raise Exception(f"Invalid netlink message length {nl_len} at offset {offset}")
            msg = NlMsg(data, offset)
            # Consecutive messages start on 4 byte boundaries (NLMSG_ALIGN).
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
296
297
# Cache of generic netlink families keyed by family name.  It stays None
# until _genl_load_families() fills it; GenlProtocol triggers that on first
# use.  Each value is a dict with 'id' and 'name' and, when the kernel
# reported them, 'maxattr' and 'mcast' ({group name: group id}).
genl_family_name_to_id = None
299
300
301def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
302    # we prepend length in _genl_msg_finalize()
303    if seq is None:
304        seq = random.randint(1, 1024)
305    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
306    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
307    return nlmsg + genlmsg
308
309
310def _genl_msg_finalize(msg):
311    return struct.pack("I", len(msg) + 4) + msg
312
313
def _genl_parse_mcast_groups(attr):
    """Return {group name: group id} decoded from a CTRL_ATTR_MCAST_GROUPS nest."""
    groups = dict()
    for entry in NlAttrs(attr.raw):
        mcast_name = None
        mcast_id = None
        for entry_attr in NlAttrs(entry.raw):
            if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                mcast_name = entry_attr.as_strz()
            elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                mcast_id = entry_attr.as_scalar('u32')
        # Only record entries that carry both a name and an id.
        if mcast_name and mcast_id is not None:
            groups[mcast_name] = mcast_id
    return groups


def _genl_load_families():
    """Dump all generic netlink families via nlctrl CTRL_CMD_GETFAMILY.

    Fills the module-level ``genl_family_name_to_id`` cache, mapping each
    family name to a dict with 'id', 'name' and, when reported, 'maxattr'
    and 'mcast'.  On a netlink error the error code is printed to stderr
    and the cache keeps whatever was read so far.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    # Diagnostic output belongs on stderr, not mixed into
                    # whatever the caller prints on stdout.
                    print("Netlink error:", nl_msg.error, file=sys.stderr)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in NlAttrs(gm.raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = _genl_parse_mcast_groups(attr)
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam
361
362
class GenlMsg:
    """A generic netlink message: a parsed netlink message plus its genl header.

    ``raw`` is the payload after the 4 byte genl header (cmd, version,
    reserved).
    """
    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        return self.genl_cmd

    def __repr__(self):
        # Start the genl line on its own line after the netlink summary.
        msg = repr(self.nl) + '\n'
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is only attached by NetlinkProtocol.decode(); printing a
        # message that was not decoded yet must not raise AttributeError.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg
378
379
class NetlinkProtocol:
    """Message framing for a classic (non-generic) netlink family."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        # Packs type, flags, seq and port id only; nlmsg_len is not included
        # here (presumably prepended by the caller once the size is known).
        sequence = random.randint(1, 1024) if seq is None else seq
        return struct.pack("HHII", nl_type, nl_flags, sequence, 0)

    def message(self, flags, command, version, seq=None):
        # The command is carried in nlmsg_type; ``version`` is not used here.
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap ``nl_msg`` and attach ``raw_attrs`` parsed after op's fixed header.

        When ``op`` is None it is looked up from the message's command in
        ``ynl.rsp_by_value``.
        """
        decoded = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[decoded.cmd()]
        header_len = ynl._struct_size(op.fixed_header)
        decoded.raw_attrs = NlAttrs(decoded.raw, header_len)
        return decoded

    def get_mcast_id(self, mcast_name, mcast_groups):
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        # Size of the netlink message header.
        return 16
412
413
class GenlProtocol(NetlinkProtocol):
    """Generic netlink framing: nlmsg_type is the family id, then a genl header.

    The family id and multicast groups come from the kernel via the
    module-level family cache, which is loaded on first use.  An unknown
    ``family_name`` raises KeyError.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        if genl_family_name_to_id is None:
            _genl_load_families()

        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']

    def message(self, flags, command, version, seq=None):
        nl_hdr = self._message(self.family_id, flags, seq)
        genl_hdr = struct.pack("BBH", command, version, 0)
        return nl_hdr + genl_hdr

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        # Group ids come from the kernel's family dump, not from the spec.
        groups = self.genl_family['mcast']
        if mcast_name not in groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return groups[mcast_name]

    def msghdr_size(self):
        # Netlink header plus the 4 byte genl header.
        return super().msghdr_size() + 4
440
441
class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes, innermost first.

    Used to look up the value of a named attribute (e.g. a sub-message
    selector) in the current nest or any enclosing one.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        inherited = outer.scopes if outer else []
        self.scopes = [self.SpecValuesPair(attr_space, attrs), *inherited]

    def lookup(self, name):
        """Return the value of ``name`` from the innermost scope defining it.

        Raises if the defining scope has no value for it, or if no scope
        defines it at all.
        """
        for spec, values in self.scopes:
            if name not in spec:
                continue
            if name in values:
                return values[name]
            raise Exception(
                f"No value for '{name}' in attribute space '{spec.yaml['name']}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")
459
460
461#
462# YNL implementation details.
463#
464
465
466class YnlFamily(SpecFamily):
467    def __init__(self, def_path, schema=None, process_unknown=False,
468                 recv_size=0):
469        super().__init__(def_path, schema)
470
471        self.include_raw = False
472        self.process_unknown = process_unknown
473
474        try:
475            if self.proto == "netlink-raw":
476                self.nlproto = NetlinkProtocol(self.yaml['name'],
477                                               self.yaml['protonum'])
478            else:
479                self.nlproto = GenlProtocol(self.yaml['name'])
480        except KeyError:
481            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
482
483        self._recv_dbg = False
484        # Note that netlink will use conservative (min) message size for
485        # the first dump recv() on the socket, our setting will only matter
486        # from the second recv() on.
487        self._recv_size = recv_size if recv_size else 131072
488        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
489        # for a message, so smaller receive sizes will lead to truncation.
490        # Note that the min size for other families may be larger than 4k!
491        if self._recv_size < 4000:
492            raise ConfigError()
493
494        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
495        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
496        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
497        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
498
499        self.async_msg_ids = set()
500        self.async_msg_queue = queue.Queue()
501
502        for msg in self.msgs.values():
503            if msg.is_async:
504                self.async_msg_ids.add(msg.rsp_value)
505
506        for op_name, op in self.ops.items():
507            bound_f = functools.partial(self._op, op_name)
508            setattr(self, op.ident_name, bound_f)
509
510
511    def ntf_subscribe(self, mcast_name):
512        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
513        self.sock.bind((0, 0))
514        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
515                             mcast_id)
516
517    def set_recv_dbg(self, enabled):
518        self._recv_dbg = enabled
519
520    def _recv_dbg_print(self, reply, nl_msgs):
521        if not self._recv_dbg:
522            return
523        print("Recv: read", len(reply), "bytes,",
524              len(nl_msgs.msgs), "messages", file=sys.stderr)
525        for nl_msg in nl_msgs:
526            print("  ", nl_msg, file=sys.stderr)
527
528    def _encode_enum(self, attr_spec, value):
529        enum = self.consts[attr_spec['enum']]
530        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
531            scalar = 0
532            if isinstance(value, str):
533                value = [value]
534            for single_value in value:
535                scalar += enum.entries[single_value].user_value(as_flags = True)
536            return scalar
537        else:
538            return enum.entries[value].user_value()
539
540    def _get_scalar(self, attr_spec, value):
541        try:
542            return int(value)
543        except (ValueError, TypeError) as e:
544            if 'enum' in attr_spec:
545                return self._encode_enum(attr_spec, value)
546            if attr_spec.display_hint:
547                return self._from_string(value, attr_spec)
548            raise e
549
550    def _add_attr(self, space, name, value, search_attrs):
551        try:
552            attr = self.attr_sets[space][name]
553        except KeyError:
554            raise Exception(f"Space '{space}' has no attribute '{name}'")
555        nl_type = attr.value
556
557        if attr.is_multi and isinstance(value, list):
558            attr_payload = b''
559            for subvalue in value:
560                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
561            return attr_payload
562
563        if attr["type"] == 'nest':
564            nl_type |= Netlink.NLA_F_NESTED
565            attr_payload = b''
566            sub_space = attr['nested-attributes']
567            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
568            for subname, subvalue in value.items():
569                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
570        elif attr["type"] == 'flag':
571            if not value:
572                # If value is absent or false then skip attribute creation.
573                return b''
574            attr_payload = b''
575        elif attr["type"] == 'string':
576            attr_payload = str(value).encode('ascii') + b'\x00'
577        elif attr["type"] == 'binary':
578            if isinstance(value, bytes):
579                attr_payload = value
580            elif isinstance(value, str):
581                if attr.display_hint:
582                    attr_payload = self._from_string(value, attr)
583                else:
584                    attr_payload = bytes.fromhex(value)
585            elif isinstance(value, dict) and attr.struct_name:
586                attr_payload = self._encode_struct(attr.struct_name, value)
587            else:
588                raise Exception(f'Unknown type for binary attribute, value: {value}')
589        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
590            scalar = self._get_scalar(attr, value)
591            if attr.is_auto_scalar:
592                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
593            else:
594                attr_type = attr["type"]
595            format = NlAttr.get_format(attr_type, attr.byte_order)
596            attr_payload = format.pack(scalar)
597        elif attr['type'] in "bitfield32":
598            scalar_value = self._get_scalar(attr, value["value"])
599            scalar_selector = self._get_scalar(attr, value["selector"])
600            attr_payload = struct.pack("II", scalar_value, scalar_selector)
601        elif attr['type'] == 'sub-message':
602            msg_format, _ = self._resolve_selector(attr, search_attrs)
603            attr_payload = b''
604            if msg_format.fixed_header:
605                attr_payload += self._encode_struct(msg_format.fixed_header, value)
606            if msg_format.attr_set:
607                if msg_format.attr_set in self.attr_sets:
608                    nl_type |= Netlink.NLA_F_NESTED
609                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
610                    for subname, subvalue in value.items():
611                        attr_payload += self._add_attr(msg_format.attr_set,
612                                                       subname, subvalue, sub_attrs)
613                else:
614                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
615        else:
616            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
617
618        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
619        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
620
621    def _decode_enum(self, raw, attr_spec):
622        enum = self.consts[attr_spec['enum']]
623        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
624            i = 0
625            value = set()
626            while raw:
627                if raw & 1:
628                    value.add(enum.entries_by_val[i].name)
629                raw >>= 1
630                i += 1
631        else:
632            value = enum.entries_by_val[raw].name
633        return value
634
635    def _decode_binary(self, attr, attr_spec):
636        if attr_spec.struct_name:
637            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
638        elif attr_spec.sub_type:
639            decoded = attr.as_c_array(attr_spec.sub_type)
640            if 'enum' in attr_spec:
641                decoded = [ self._decode_enum(x, attr_spec) for x in decoded ]
642            elif attr_spec.display_hint:
643                decoded = [ self._formatted_string(x, attr_spec.display_hint)
644                            for x in decoded ]
645        else:
646            decoded = attr.as_bin()
647            if attr_spec.display_hint:
648                decoded = self._formatted_string(decoded, attr_spec.display_hint)
649        return decoded
650
651    def _decode_array_attr(self, attr, attr_spec):
652        decoded = []
653        offset = 0
654        while offset < len(attr.raw):
655            item = NlAttr(attr.raw, offset)
656            offset += item.full_len
657
658            if attr_spec["sub-type"] == 'nest':
659                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
660                decoded.append({ item.type: subattrs })
661            elif attr_spec["sub-type"] == 'binary':
662                subattr = item.as_bin()
663                if attr_spec.display_hint:
664                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
665                decoded.append(subattr)
666            elif attr_spec["sub-type"] in NlAttr.type_formats:
667                subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
668                if 'enum' in attr_spec:
669                    subattr = self._decode_enum(subattr, attr_spec)
670                elif attr_spec.display_hint:
671                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
672                decoded.append(subattr)
673            else:
674                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
675        return decoded
676
677    def _decode_nest_type_value(self, attr, attr_spec):
678        decoded = {}
679        value = attr
680        for name in attr_spec['type-value']:
681            value = NlAttr(value.raw, 0)
682            decoded[name] = value.type
683        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
684        decoded.update(subattrs)
685        return decoded
686
687    def _decode_unknown(self, attr):
688        if attr.is_nest:
689            return self._decode(NlAttrs(attr.raw), None)
690        else:
691            return attr.as_bin()
692
693    def _rsp_add(self, rsp, name, is_multi, decoded):
694        if is_multi == None:
695            if name in rsp and type(rsp[name]) is not list:
696                rsp[name] = [rsp[name]]
697                is_multi = True
698            else:
699                is_multi = False
700
701        if not is_multi:
702            rsp[name] = decoded
703        elif name in rsp:
704            rsp[name].append(decoded)
705        else:
706            rsp[name] = [decoded]
707
708    def _resolve_selector(self, attr_spec, search_attrs):
709        sub_msg = attr_spec.sub_message
710        if sub_msg not in self.sub_msgs:
711            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
712        sub_msg_spec = self.sub_msgs[sub_msg]
713
714        selector = attr_spec.selector
715        value = search_attrs.lookup(selector)
716        if value not in sub_msg_spec.formats:
717            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
718
719        spec = sub_msg_spec.formats[value]
720        return spec, value
721
722    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
723        msg_format, _ = self._resolve_selector(attr_spec, search_attrs)
724        decoded = {}
725        offset = 0
726        if msg_format.fixed_header:
727            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header));
728            offset = self._struct_size(msg_format.fixed_header)
729        if msg_format.attr_set:
730            if msg_format.attr_set in self.attr_sets:
731                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
732                decoded.update(subdict)
733            else:
734                raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")
735        return decoded
736
737    def _decode(self, attrs, space, outer_attrs = None):
738        rsp = dict()
739        if space:
740            attr_space = self.attr_sets[space]
741            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
742
743        for attr in attrs:
744            try:
745                attr_spec = attr_space.attrs_by_val[attr.type]
746            except (KeyError, UnboundLocalError):
747                if not self.process_unknown:
748                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
749                attr_name = f"UnknownAttr({attr.type})"
750                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
751                continue
752
753            try:
754                if attr_spec["type"] == 'nest':
755                    subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
756                    decoded = subdict
757                elif attr_spec["type"] == 'string':
758                    decoded = attr.as_strz()
759                elif attr_spec["type"] == 'binary':
760                    decoded = self._decode_binary(attr, attr_spec)
761                elif attr_spec["type"] == 'flag':
762                    decoded = True
763                elif attr_spec.is_auto_scalar:
764                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
765                elif attr_spec["type"] in NlAttr.type_formats:
766                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
767                    if 'enum' in attr_spec:
768                        decoded = self._decode_enum(decoded, attr_spec)
769                    elif attr_spec.display_hint:
770                        decoded = self._formatted_string(decoded, attr_spec.display_hint)
771                elif attr_spec["type"] == 'indexed-array':
772                    decoded = self._decode_array_attr(attr, attr_spec)
773                elif attr_spec["type"] == 'bitfield32':
774                    value, selector = struct.unpack("II", attr.raw)
775                    if 'enum' in attr_spec:
776                        value = self._decode_enum(value, attr_spec)
777                        selector = self._decode_enum(selector, attr_spec)
778                    decoded = {"value": value, "selector": selector}
779                elif attr_spec["type"] == 'sub-message':
780                    decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
781                elif attr_spec["type"] == 'nest-type-value':
782                    decoded = self._decode_nest_type_value(attr, attr_spec)
783                else:
784                    if not self.process_unknown:
785                        raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
786                    decoded = self._decode_unknown(attr)
787
788                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
789            except:
790                print(f"Error decoding '{attr_spec.name}' from '{space}'")
791                raise
792
793        return rsp
794
    def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs):
        """
        Find the attribute that starts at byte offset @target and describe it.

        @attrs:        attributes to scan (iterated; NlAttrs for nested levels)
        @attr_set:     spec of the attribute space @attrs belong to
        @offset:       byte offset, within the request message, of the first
                       attribute in @attrs
        @target:       byte offset reported by the kernel in the extack
        @search_attrs: SpaceAttrs scope used to resolve sub-message selectors

        Returns a dotted path such as '.attr', '.nest.attr' or
        '.submsg(selector-value).attr', or None if @target does not coincide
        with the start of any attribute that is reached.

        Raises Exception if an attribute's type is not in @attr_set, if a
        sub-message format cannot be resolved, or if @target falls inside an
        attribute that is neither a nest nor a sub-message.
        """
        for attr in attrs:
            # NOTE(review): the spec lookup runs before the offset checks
            # below, so an unknown attribute located after @target still raises.
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            # Walked past the target without landing on an attribute start.
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            # Target lies beyond this attribute: skip over it.
            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue

            # Target lies inside this attribute: descend into its payload.
            pathname = attr_spec.name
            if attr_spec['type'] == 'nest':
                sub_attrs = self.attr_sets[attr_spec['nested-attributes']]
                search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name']))
            elif attr_spec['type'] == 'sub-message':
                # NOTE(review): unlike the nest branch, search_attrs is not
                # narrowed to the sub-message scope here — confirm intended.
                msg_format, value = self._resolve_selector(attr_spec, search_attrs)
                if msg_format is None:
                    raise Exception(f"Can't resolve sub-message of {attr_spec['name']} for extack")
                sub_attrs = self.attr_sets[msg_format.attr_set]
                pathname += f"({value})"
            else:
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Step over this attribute's own header; presumably the 4-byte
            # nlattr header (length + type) — confirm against NlAttr.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs,
                                               offset, target, search_attrs)
            if subpath is None:
                return None
            return '.' + pathname + subpath

        return None
830
831    def _decode_extack(self, request, op, extack, vals):
832        if 'bad-attr-offs' not in extack:
833            return
834
835        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
836        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
837        search_attrs = SpaceAttrs(op.attr_set, vals)
838        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
839                                        extack['bad-attr-offs'], search_attrs)
840        if path:
841            del extack['bad-attr-offs']
842            extack['bad-attr'] = path
843
844    def _struct_size(self, name):
845        if name:
846            members = self.consts[name].members
847            size = 0
848            for m in members:
849                if m.type in ['pad', 'binary']:
850                    if m.struct:
851                        size += self._struct_size(m.struct)
852                    else:
853                        size += m.len
854                else:
855                    format = NlAttr.get_format(m.type, m.byte_order)
856                    size += format.size
857            return size
858        else:
859            return 0
860
861    def _decode_struct(self, data, name):
862        members = self.consts[name].members
863        attrs = dict()
864        offset = 0
865        for m in members:
866            value = None
867            if m.type == 'pad':
868                offset += m.len
869            elif m.type == 'binary':
870                if m.struct:
871                    len = self._struct_size(m.struct)
872                    value = self._decode_struct(data[offset : offset + len],
873                                                m.struct)
874                    offset += len
875                else:
876                    value = data[offset : offset + m.len]
877                    offset += m.len
878            else:
879                format = NlAttr.get_format(m.type, m.byte_order)
880                [ value ] = format.unpack_from(data, offset)
881                offset += format.size
882            if value is not None:
883                if m.enum:
884                    value = self._decode_enum(value, m)
885                elif m.display_hint:
886                    value = self._formatted_string(value, m.display_hint)
887                attrs[m.name] = value
888        return attrs
889
890    def _encode_struct(self, name, vals):
891        members = self.consts[name].members
892        attr_payload = b''
893        for m in members:
894            value = vals.pop(m.name) if m.name in vals else None
895            if m.type == 'pad':
896                attr_payload += bytearray(m.len)
897            elif m.type == 'binary':
898                if m.struct:
899                    if value is None:
900                        value = dict()
901                    attr_payload += self._encode_struct(m.struct, value)
902                else:
903                    if value is None:
904                        attr_payload += bytearray(m.len)
905                    else:
906                        attr_payload += bytes.fromhex(value)
907            else:
908                if value is None:
909                    value = 0
910                format = NlAttr.get_format(m.type, m.byte_order)
911                attr_payload += format.pack(value)
912        return attr_payload
913
914    def _formatted_string(self, raw, display_hint):
915        if display_hint == 'mac':
916            formatted = ':'.join('%02x' % b for b in raw)
917        elif display_hint == 'hex':
918            if isinstance(raw, int):
919                formatted = hex(raw)
920            else:
921                formatted = bytes.hex(raw, ' ')
922        elif display_hint in [ 'ipv4', 'ipv6' ]:
923            formatted = format(ipaddress.ip_address(raw))
924        elif display_hint == 'uuid':
925            formatted = str(uuid.UUID(bytes=raw))
926        else:
927            formatted = raw
928        return formatted
929
930    def _from_string(self, string, attr_spec):
931        if attr_spec.display_hint in ['ipv4', 'ipv6']:
932            ip = ipaddress.ip_address(string)
933            if attr_spec['type'] == 'binary':
934                raw = ip.packed
935            else:
936                raw = int(ip)
937        else:
938            raise Exception(f"Display hint '{attr_spec.display_hint}' not implemented"
939                            f" when parsing '{attr_spec['name']}'")
940        return raw
941
942    def handle_ntf(self, decoded):
943        msg = dict()
944        if self.include_raw:
945            msg['raw'] = decoded
946        op = self.rsp_by_value[decoded.cmd()]
947        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
948        if op.fixed_header:
949            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
950
951        msg['name'] = op['name']
952        msg['msg'] = attrs
953        self.async_msg_queue.put(msg)
954
955    def check_ntf(self):
956        while True:
957            try:
958                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
959            except BlockingIOError:
960                return
961
962            nms = NlMsgs(reply)
963            self._recv_dbg_print(reply, nms)
964            for nl_msg in nms:
965                if nl_msg.error:
966                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
967                    print(nl_msg)
968                    continue
969                if nl_msg.done:
970                    print("Netlink done while checking for ntf!?")
971                    continue
972
973                decoded = self.nlproto.decode(self, nl_msg, None)
974                if decoded.cmd() not in self.async_msg_ids:
975                    print("Unexpected msg id while checking for ntf", decoded)
976                    continue
977
978                self.handle_ntf(decoded)
979
980    def poll_ntf(self, duration=None):
981        start_time = time.time()
982        selector = selectors.DefaultSelector()
983        selector.register(self.sock, selectors.EVENT_READ)
984
985        while True:
986            try:
987                yield self.async_msg_queue.get_nowait()
988            except queue.Empty:
989                if duration is not None:
990                    timeout = start_time + duration - time.time()
991                    if timeout <= 0:
992                        return
993                else:
994                    timeout = None
995                events = selector.select(timeout)
996                if events:
997                    self.check_ntf()
998
999    def operation_do_attributes(self, name):
1000      """
1001      For a given operation name, find and return a supported
1002      set of attributes (as a dict).
1003      """
1004      op = self.find_operation(name)
1005      if not op:
1006        return None
1007
1008      return op['do']['request']['attributes'].copy()
1009
1010    def _encode_message(self, op, vals, flags, req_seq):
1011        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
1012        for flag in flags or []:
1013            nl_flags |= flag
1014
1015        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
1016        if op.fixed_header:
1017            msg += self._encode_struct(op.fixed_header, vals)
1018        search_attrs = SpaceAttrs(op.attr_set, vals)
1019        for name, value in vals.items():
1020            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
1021        msg = _genl_msg_finalize(msg)
1022        return msg
1023
    def _ops(self, ops):
        """
        Send a batch of requests in a single send() and collect the replies.

        @ops: iterable of (method, vals, flags) tuples.  @method is a key of
              self.ops, @vals the values to encode, and @flags a list of extra
              NLM_F_* flags (or None).

        Returns a list with one entry appended each time a message with
        nl_msg.done set is received for one of the requests (in arrival
        order).  For a request whose flags list contains Netlink.NLM_F_DUMP,
        the entry is the list of decoded replies.  Otherwise it is None when
        there were no replies, the single reply dict when there was one, or
        the list of replies.

        Raises NlError as soon as a received message has nl_msg.error set.
        Messages that do not answer one of the requests are handed to
        handle_ntf() when their command is in self.async_msg_ids, and are
        otherwise printed and dropped.
        """
        reqs_by_seq = {}
        # Random starting sequence number; each request in the batch gets the
        # next consecutive value so replies can be matched back to it.
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            reqs_by_seq[req_seq] = (op, vals, msg, flags)
            payload += msg
            req_seq += 1

        self.sock.send(payload, 0)

        done = False
        rsp = []
        # Decoded replies of the request currently being answered.  This
        # assumes replies of different requests are not interleaved — confirm.
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                # Match the message to its request; unknown sequence numbers
                # (e.g. notifications) get op=None and no flags.
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    if nl_msg.extack:
                        nl_msg.annotate_extack(op.attr_set)
                        self._decode_extack(req_msg, op, nl_msg.extack, vals)
                else:
                    op = None
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    # NOTE(review): this is list membership of the combined
                    # NLM_F_ROOT | NLM_F_MATCH value, so a dump requested by
                    # passing the two flags as separate items is not detected.
                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    # NOTE(review): a DONE with an unknown sequence number
                    # raises KeyError here.  The 'break' also skips any
                    # messages that follow in the same recv() buffer; this
                    # presumably relies on the kernel not packing further
                    # replies after it — confirm.
                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    break

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp
1092
1093    def _op(self, method, vals, flags=None, dump=False):
1094        req_flags = flags or []
1095        if dump:
1096            req_flags.append(Netlink.NLM_F_DUMP)
1097
1098        ops = [(method, vals, req_flags)]
1099        return self._ops(ops)[0]
1100
1101    def do(self, method, vals, flags=None):
1102        return self._op(method, vals, flags)
1103
1104    def dump(self, method, vals):
1105        return self._op(method, vals, dump=True)
1106
1107    def do_multi(self, ops):
1108        return self._ops(ops)
1109