xref: /linux/tools/net/ynl/pyynl/lib/ynl.py (revision d103f26a5c8599385acb2d2e01dfbaedb00fdc0a)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4from enum import Enum
5import functools
6import os
7import random
8import socket
9import struct
10from struct import Struct
11import sys
12import ipaddress
13import uuid
14import queue
15import selectors
16import time
17
18from .nlspec import SpecFamily
19
20#
21# Generic Netlink code which should really be in some library, but I can't quickly find one.
22#
23
24
class Netlink:
    """Namespace for the numeric Netlink / Generic Netlink constants used in this file.

    The class is never instantiated here; its attributes are read as
    ``Netlink.<NAME>`` by the parsing and encoding code.
    """
    # Netlink socket
    # Socket option level passed to setsockopt() for the NETLINK_* options below.
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    # Message types that NlMsg treats as terminating (it sets ``done``).
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # Header flags. Note the three groups below reuse the same bit values
    # (0x100, 0x200, ...); presumably the applicable meaning depends on the
    # kind of message -- confirm against the kernel uapi headers.
    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # Flag bits carried in the upper part of an attribute's type field.
    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    # NOTE: this is the set of *flag* bits; NlAttr applies it inverted
    # (``& ~NLA_TYPE_MASK``) to extract the attribute type.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    # Attributes found inside each CTRL_ATTR_MCAST_GROUPS entry.
    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Names for the value of NL_POLICY_TYPE_ATTR_TYPE. Enum() numbers the
    # members from 1 in the order listed; NlMsg._decode_policy() looks a
    # member up by that number and reports its name.
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                  's8', 's16', 's32', 's64',
                                  'binary', 'string', 'nul-string',
                                  'nested', 'nested-array',
                                  'bitfield32', 'sint', 'uint'])
101
class NlError(Exception):
    """Exception wrapping a Netlink message that carries an error code.

    Attributes:
        nl_msg: the message object the error was taken from.
        error: ``nl_msg.error`` with the sign flipped, as handed to
            ``os.strerror()`` when formatting.
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        reason = os.strerror(self.error)
        return f"Netlink error: {reason}\n{self.nl_msg}"
109
110
class ConfigError(Exception):
    """Raised when a caller-supplied configuration value is unusable."""
113
114
class NlAttr:
    """One Netlink attribute parsed from a raw buffer.

    The attribute starts with a 4-byte header (two native-endian 16-bit
    fields: length, then type). ``raw`` holds the payload without the header,
    ``full_len`` is the length rounded up to a multiple of 4 so callers can
    step to the next attribute.
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    # Pre-compiled struct formats per scalar type, in native / big-endian /
    # little-endian flavours (single-byte types need no byte-order prefix).
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute located at ``offset`` inside ``raw``."""
        header = raw[offset : offset + 4]
        self._len, self._type = struct.unpack("HH", header)
        # Strip the flag bits to obtain the plain attribute type.
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        # Round up to the next multiple of 4.
        self.full_len = (self.payload_len + 3) & ~3
        payload_end = offset + self.payload_len
        self.raw = raw[offset + 4 : payload_end]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for ``attr_type``.

        A falsy ``byte_order`` selects the native format, "big-endian" the
        big-endian one and any other truthy value the little-endian one.
        Raises KeyError for a type missing from ``type_formats``.
        """
        formats = cls.type_formats[attr_type]
        if not byte_order:
            return formats.native
        if byte_order == "big-endian":
            return formats.big
        return formats.little

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as a single scalar of ``attr_type``."""
        unpacker = self.get_format(attr_type, byte_order)
        (value,) = unpacker.unpack(self.raw)
        return value

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable-width scalar whose width is the payload size.

        The concrete type is the first letter of ``attr_type`` followed by
        the payload width in bits (32 or 64). Raises Exception for any other
        payload size.
        """
        size = len(self.raw)
        if size not in (4, 8):
            raise Exception(f"Auto-scalar len payload be 4 or 8 bytes, got {size}")
        sized_type = attr_type[0] + str(size * 8)
        unpacker = self.get_format(sized_type, byte_order)
        (value,) = unpacker.unpack(self.raw)
        return value

    def as_strz(self):
        """Decode the payload as ASCII and drop its last character."""
        text = self.raw.decode('ascii')
        return text[:-1]

    def as_bin(self):
        """Return the payload bytes untouched."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of native ``type`` scalars."""
        unpacker = self.get_format(type)
        return [ item[0] for item in unpacker.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
167
168
class NlAttrs:
    """Sequence of NlAttr objects parsed back-to-back from a buffer."""

    def __init__(self, msg, offset=0):
        """Parse attributes from ``msg`` starting at ``offset`` until its end."""
        self.attrs = []

        end = len(msg)
        while offset < end:
            parsed = NlAttr(msg, offset)
            # Advance by the padded length to land on the next attribute.
            offset += parsed.full_len
            self.attrs.append(parsed)

    def __iter__(self):
        return iter(self.attrs)

    def __repr__(self):
        # One attribute per line.
        return '\n'.join(repr(attr) for attr in self.attrs)
188
189
class NlMsg:
    """One Netlink message: 16-byte header plus payload.

    For NLMSG_ERROR and NLMSG_DONE messages the leading 32-bit error code is
    extracted into ``error`` and ``done`` is set. When the message carries the
    NLM_F_ACK_TLVS flag, the trailing extended-ack attributes are decoded into
    the ``extack`` dict (otherwise ``extack`` is None).
    """

    def __init__(self, msg, offset, attr_space=None):
        """Parse the message at ``offset`` in ``msg``.

        ``attr_space``, when given, is used to turn a numeric 'miss-type'
        extack into an attribute name (see annotate_extack()).
        """
        self.hdr = msg[offset : offset + 16]

        (self.nl_len, self.nl_type, self.nl_flags,
         self.nl_seq, self.nl_portid) = struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        # Offset of the extack attributes inside the payload, per message type.
        if self.nl_type == Netlink.NLMSG_ERROR:
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            extack_off = 4
        else:
            extack_off = None

        if extack_off is not None:
            # Both terminating message types start with a signed error code.
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1

        self.extack = None
        if not (self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off):
            return

        self.extack = dict()
        u32_fields = {
            Netlink.NLMSGERR_ATTR_MISS_TYPE: 'miss-type',
            Netlink.NLMSGERR_ATTR_MISS_NEST: 'miss-nest',
            Netlink.NLMSGERR_ATTR_OFFS: 'bad-attr-offs',
        }
        for tlv in NlAttrs(self.raw[extack_off:]):
            if tlv.type == Netlink.NLMSGERR_ATTR_MSG:
                self.extack['msg'] = tlv.as_strz()
            elif tlv.type in u32_fields:
                self.extack[u32_fields[tlv.type]] = tlv.as_scalar('u32')
            elif tlv.type == Netlink.NLMSGERR_ATTR_POLICY:
                self.extack['policy'] = self._decode_policy(tlv.raw)
            else:
                # Keep anything we do not understand for inspection.
                self.extack.setdefault('unknown', []).append(tlv)

        if attr_space:
            self.annotate_extack(attr_space)

    def _decode_policy(self, raw):
        """Decode a nested policy description into a plain dict."""
        scalar_fields = {
            Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S: ('min-value', 's64'),
            Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S: ('max-value', 's64'),
            Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U: ('min-value', 'u64'),
            Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U: ('max-value', 'u64'),
            Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH: ('min-length', 'u32'),
            Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH: ('max-length', 'u32'),
            Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK: ('bitfield32-mask', 'u32'),
            Netlink.NL_POLICY_TYPE_ATTR_MASK: ('mask', 'u64'),
        }
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                kind = attr.as_scalar('u32')
                policy['type'] = Netlink.AttrType(kind).name
                continue
            field = scalar_fields.get(attr.type)
            if field is not None:
                key, scalar_type = field
                policy[key] = attr.as_scalar(scalar_type)
        return policy

    def annotate_extack(self, attr_space):
        """ Make extack more human friendly with attribute information """

        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' not in self.extack or 'miss-nest' in self.extack:
            return
        miss_type = self.extack['miss-type']
        if miss_type not in attr_space.attrs_by_val:
            return
        spec = attr_space.attrs_by_val[miss_type]
        self.extack['miss-type'] = spec['name']
        if 'doc' in spec:
            self.extack['miss-type-doc'] = spec['doc']

    def cmd(self):
        """Return the message type from the Netlink header."""
        return self.nl_type

    def __repr__(self):
        parts = [f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"]
        if self.error:
            parts.append('\n\terror: ' + str(self.error))
        if self.extack:
            parts.append('\n\textack: ' + repr(self.extack))
        return ''.join(parts)
281
282
class NlMsgs:
    """All Netlink messages found back-to-back in one receive buffer."""

    def __init__(self, data):
        """Split ``data`` into NlMsg objects, stepping by each header length."""
        self.msgs = []

        position = 0
        total = len(data)
        while position < total:
            parsed = NlMsg(data, position)
            position += parsed.nl_len
            self.msgs.append(parsed)

    def __iter__(self):
        return iter(self.msgs)
295
296
# Module-wide cache mapping a Generic Netlink family name to a dict describing
# that family ('id', 'name', and when reported 'maxattr' / 'mcast').
# Stays None until _genl_load_families() has been called.
genl_family_name_to_id = None
298
299
300def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
301    # we prepend length in _genl_msg_finalize()
302    if seq is None:
303        seq = random.randint(1, 1024)
304    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
305    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
306    return nlmsg + genlmsg
307
308
309def _genl_msg_finalize(msg):
310    return struct.pack("I", len(msg) + 4) + msg
311
312
def _genl_load_families():
    """Dump the Generic Netlink families and cache them by name.

    Sends a CTRL_CMD_GETFAMILY dump request on a temporary socket and fills
    the module-level ``genl_family_name_to_id`` dict with one entry per family
    that reported both a name and an id. Returns when a done message arrives;
    on an error message it prints the error code and returns early, leaving
    whatever was collected so far.
    """
    global genl_family_name_to_id

    dump_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL, dump_flags,
                      Netlink.CTRL_CMD_GETFAMILY, 1))

        sock.send(request, 0)

        genl_family_name_to_id = {}

        while True:
            reply = sock.recv(128 * 1024)
            for nl_msg in NlMsgs(reply):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                family = {}
                for attr in NlAttrs(GenlMsg(nl_msg).raw):
                    kind = attr.type
                    if kind == Netlink.CTRL_ATTR_FAMILY_ID:
                        family['id'] = attr.as_scalar('u16')
                    elif kind == Netlink.CTRL_ATTR_FAMILY_NAME:
                        family['name'] = attr.as_strz()
                    elif kind == Netlink.CTRL_ATTR_MAXATTR:
                        family['maxattr'] = attr.as_scalar('u32')
                    elif kind == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        groups = {}
                        family['mcast'] = groups
                        # Each entry is itself a nest holding name and id.
                        for entry in NlAttrs(attr.raw):
                            group_name = None
                            group_id = None
                            for field in NlAttrs(entry.raw):
                                if field.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    group_name = field.as_strz()
                                elif field.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    group_id = field.as_scalar('u32')
                            if group_name and group_id is not None:
                                groups[group_name] = group_id
                if 'name' in family and 'id' in family:
                    genl_family_name_to_id[family['name']] = family
360
361
class GenlMsg:
    """Generic Netlink view of an already parsed Netlink message.

    Splits the 4-byte genetlink header (command, version, 16 reserved bits)
    off the front of ``nl_msg.raw`` and keeps the remainder in ``raw``.

    Attributes:
        nl: the wrapped Netlink message.
        genl_cmd, genl_version: fields of the genetlink header.
        raw: payload following the genetlink header.
        raw_attrs: NOT set here; it is attached later by the protocol's
            decode() step, so it may be absent on a freshly built object.
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the genetlink command number."""
        return self.genl_cmd

    def __repr__(self):
        parts = [
            repr(self.nl),
            f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n",
        ]
        # raw_attrs only exists once the message went through decode();
        # repr() must not blow up on a message that has not been decoded yet.
        for a in getattr(self, 'raw_attrs', ()):
            parts.append('\t\t' + repr(a) + '\n')
        return ''.join(parts)
376        return msg
377
378
class NetlinkProtocol:
    """Message framing for a raw Netlink family (no genetlink header)."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack the Netlink header fields after the length: type, flags, seq, port id 0.

        A random sequence number in [1, 1024] is picked when ``seq`` is None.
        """
        sequence = random.randint(1, 1024) if seq is None else seq
        return struct.pack("HHII", nl_type, nl_flags, sequence, 0)

    def message(self, flags, command, version, seq=None):
        """Build a request header; ``command`` is used as the message type.

        ``version`` is accepted for signature parity with subclasses and is
        not used here.
        """
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        """Return the protocol-level view of ``nl_msg`` (the message itself here)."""
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Attach parsed attributes to the message as ``raw_attrs`` and return it.

        When ``op`` is None the operation is looked up in ``ynl.rsp_by_value``
        by the message's command. Attributes are parsed starting after the
        operation's fixed header.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        attrs_start = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, attrs_start)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the ``value`` of the named group from ``mcast_groups``.

        Raises Exception when the group is not listed there.
        """
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        """Size in bytes of the message header for this protocol."""
        return 16
411
412
class GenlProtocol(NetlinkProtocol):
    """Message framing for a Generic Netlink family.

    The family's numeric id and multicast groups come from the module-level
    family cache, which is loaded on first use. Constructing an instance for
    a family name missing from that cache raises KeyError.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        # Populate the shared cache the first time any family is requested.
        if genl_family_name_to_id is None:
            _genl_load_families()

        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']

    def message(self, flags, command, version, seq=None):
        """Build Netlink + genetlink headers; the family id is the message type."""
        netlink_hdr = self._message(self.family_id, flags, seq)
        genl_hdr = struct.pack("BBH", command, version, 0)
        return netlink_hdr + genl_hdr

    def _decode(self, nl_msg):
        """Wrap the message so the genetlink header is split off."""
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the group id recorded for this family in the family cache.

        ``mcast_groups`` is ignored here. Raises Exception when the family
        entry has no such group.
        """
        known_groups = self.genl_family['mcast']
        if mcast_name in known_groups:
            return known_groups[mcast_name]
        raise Exception(f'Multicast group "{mcast_name}" not present in the family')

    def msghdr_size(self):
        """Base header size plus the 4-byte genetlink header."""
        return super().msghdr_size() + 4
439
440
class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes, innermost first.

    Used to look an attribute value up by name in the current attribute set
    and, failing that, in the enclosing ones.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        """Create a scope for ``attr_space``/``attrs`` nested inside ``outer``."""
        enclosing = outer.scopes if outer else []
        innermost = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [innermost, *enclosing]

    def lookup(self, name):
        """Return the value of ``name`` from the first scope whose spec defines it.

        Raises Exception if that scope holds no value for the name, or if no
        scope defines the name at all.
        """
        for spec, values in self.scopes:
            if name not in spec:
                continue
            if name in values:
                return values[name]
            spec_name = spec.yaml['name']
            raise Exception(
                f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")
458
459
460#
461# YNL implementation details.
462#
463
464
465class YnlFamily(SpecFamily):
466    def __init__(self, def_path, schema=None, process_unknown=False,
467                 recv_size=0):
468        super().__init__(def_path, schema)
469
470        self.include_raw = False
471        self.process_unknown = process_unknown
472
473        try:
474            if self.proto == "netlink-raw":
475                self.nlproto = NetlinkProtocol(self.yaml['name'],
476                                               self.yaml['protonum'])
477            else:
478                self.nlproto = GenlProtocol(self.yaml['name'])
479        except KeyError:
480            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
481
482        self._recv_dbg = False
483        # Note that netlink will use conservative (min) message size for
484        # the first dump recv() on the socket, our setting will only matter
485        # from the second recv() on.
486        self._recv_size = recv_size if recv_size else 131072
487        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
488        # for a message, so smaller receive sizes will lead to truncation.
489        # Note that the min size for other families may be larger than 4k!
490        if self._recv_size < 4000:
491            raise ConfigError()
492
493        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
494        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
495        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
496        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
497
498        self.async_msg_ids = set()
499        self.async_msg_queue = queue.Queue()
500
501        for msg in self.msgs.values():
502            if msg.is_async:
503                self.async_msg_ids.add(msg.rsp_value)
504
505        for op_name, op in self.ops.items():
506            bound_f = functools.partial(self._op, op_name)
507            setattr(self, op.ident_name, bound_f)
508
509
510    def ntf_subscribe(self, mcast_name):
511        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
512        self.sock.bind((0, 0))
513        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
514                             mcast_id)
515
516    def set_recv_dbg(self, enabled):
517        self._recv_dbg = enabled
518
519    def _recv_dbg_print(self, reply, nl_msgs):
520        if not self._recv_dbg:
521            return
522        print("Recv: read", len(reply), "bytes,",
523              len(nl_msgs.msgs), "messages", file=sys.stderr)
524        for nl_msg in nl_msgs:
525            print("  ", nl_msg, file=sys.stderr)
526
527    def _encode_enum(self, attr_spec, value):
528        enum = self.consts[attr_spec['enum']]
529        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
530            scalar = 0
531            if isinstance(value, str):
532                value = [value]
533            for single_value in value:
534                scalar += enum.entries[single_value].user_value(as_flags = True)
535            return scalar
536        else:
537            return enum.entries[value].user_value()
538
539    def _get_scalar(self, attr_spec, value):
540        try:
541            return int(value)
542        except (ValueError, TypeError) as e:
543            if 'enum' in attr_spec:
544                return self._encode_enum(attr_spec, value)
545            if attr_spec.display_hint:
546                return self._from_string(value, attr_spec)
547            raise e
548
549    def _add_attr(self, space, name, value, search_attrs):
550        try:
551            attr = self.attr_sets[space][name]
552        except KeyError:
553            raise Exception(f"Space '{space}' has no attribute '{name}'")
554        nl_type = attr.value
555
556        if attr.is_multi and isinstance(value, list):
557            attr_payload = b''
558            for subvalue in value:
559                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
560            return attr_payload
561
562        if attr["type"] == 'nest':
563            nl_type |= Netlink.NLA_F_NESTED
564            attr_payload = b''
565            sub_space = attr['nested-attributes']
566            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
567            for subname, subvalue in value.items():
568                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
569        elif attr["type"] == 'flag':
570            if not value:
571                # If value is absent or false then skip attribute creation.
572                return b''
573            attr_payload = b''
574        elif attr["type"] == 'string':
575            attr_payload = str(value).encode('ascii') + b'\x00'
576        elif attr["type"] == 'binary':
577            if value is None:
578                attr_payload = b''
579            elif isinstance(value, bytes):
580                attr_payload = value
581            elif isinstance(value, str):
582                if attr.display_hint:
583                    attr_payload = self._from_string(value, attr)
584                else:
585                    attr_payload = bytes.fromhex(value)
586            elif isinstance(value, dict) and attr.struct_name:
587                attr_payload = self._encode_struct(attr.struct_name, value)
588            elif isinstance(value, list) and attr.sub_type in NlAttr.type_formats:
589                format = NlAttr.get_format(attr.sub_type)
590                attr_payload = b''.join([format.pack(x) for x in value])
591            else:
592                raise Exception(f'Unknown type for binary attribute, value: {value}')
593        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
594            scalar = self._get_scalar(attr, value)
595            if attr.is_auto_scalar:
596                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
597            else:
598                attr_type = attr["type"]
599            format = NlAttr.get_format(attr_type, attr.byte_order)
600            attr_payload = format.pack(scalar)
601        elif attr['type'] in "bitfield32":
602            scalar_value = self._get_scalar(attr, value["value"])
603            scalar_selector = self._get_scalar(attr, value["selector"])
604            attr_payload = struct.pack("II", scalar_value, scalar_selector)
605        elif attr['type'] == 'sub-message':
606            msg_format, _ = self._resolve_selector(attr, search_attrs)
607            attr_payload = b''
608            if msg_format.fixed_header:
609                attr_payload += self._encode_struct(msg_format.fixed_header, value)
610            if msg_format.attr_set:
611                if msg_format.attr_set in self.attr_sets:
612                    nl_type |= Netlink.NLA_F_NESTED
613                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
614                    for subname, subvalue in value.items():
615                        attr_payload += self._add_attr(msg_format.attr_set,
616                                                       subname, subvalue, sub_attrs)
617                else:
618                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
619        else:
620            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
621
622        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
623        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
624
625    def _get_enum_or_unknown(self, enum, raw):
626        try:
627            name = enum.entries_by_val[raw].name
628        except KeyError as error:
629            if self.process_unknown:
630                name = f"Unknown({raw})"
631            else:
632                raise error
633        return name
634
635    def _decode_enum(self, raw, attr_spec):
636        enum = self.consts[attr_spec['enum']]
637        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
638            i = 0
639            value = set()
640            while raw:
641                if raw & 1:
642                    value.add(self._get_enum_or_unknown(enum, i))
643                raw >>= 1
644                i += 1
645        else:
646            value = self._get_enum_or_unknown(enum, raw)
647        return value
648
649    def _decode_binary(self, attr, attr_spec):
650        if attr_spec.struct_name:
651            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
652        elif attr_spec.sub_type:
653            decoded = attr.as_c_array(attr_spec.sub_type)
654            if 'enum' in attr_spec:
655                decoded = [ self._decode_enum(x, attr_spec) for x in decoded ]
656            elif attr_spec.display_hint:
657                decoded = [ self._formatted_string(x, attr_spec.display_hint)
658                            for x in decoded ]
659        else:
660            decoded = attr.as_bin()
661            if attr_spec.display_hint:
662                decoded = self._formatted_string(decoded, attr_spec.display_hint)
663        return decoded
664
665    def _decode_array_attr(self, attr, attr_spec):
666        decoded = []
667        offset = 0
668        while offset < len(attr.raw):
669            item = NlAttr(attr.raw, offset)
670            offset += item.full_len
671
672            if attr_spec["sub-type"] == 'nest':
673                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
674                decoded.append({ item.type: subattrs })
675            elif attr_spec["sub-type"] == 'binary':
676                subattr = item.as_bin()
677                if attr_spec.display_hint:
678                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
679                decoded.append(subattr)
680            elif attr_spec["sub-type"] in NlAttr.type_formats:
681                subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
682                if 'enum' in attr_spec:
683                    subattr = self._decode_enum(subattr, attr_spec)
684                elif attr_spec.display_hint:
685                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
686                decoded.append(subattr)
687            else:
688                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
689        return decoded
690
691    def _decode_nest_type_value(self, attr, attr_spec):
692        decoded = {}
693        value = attr
694        for name in attr_spec['type-value']:
695            value = NlAttr(value.raw, 0)
696            decoded[name] = value.type
697        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
698        decoded.update(subattrs)
699        return decoded
700
701    def _decode_unknown(self, attr):
702        if attr.is_nest:
703            return self._decode(NlAttrs(attr.raw), None)
704        else:
705            return attr.as_bin()
706
707    def _rsp_add(self, rsp, name, is_multi, decoded):
708        if is_multi is None:
709            if name in rsp and type(rsp[name]) is not list:
710                rsp[name] = [rsp[name]]
711                is_multi = True
712            else:
713                is_multi = False
714
715        if not is_multi:
716            rsp[name] = decoded
717        elif name in rsp:
718            rsp[name].append(decoded)
719        else:
720            rsp[name] = [decoded]
721
722    def _resolve_selector(self, attr_spec, search_attrs):
723        sub_msg = attr_spec.sub_message
724        if sub_msg not in self.sub_msgs:
725            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
726        sub_msg_spec = self.sub_msgs[sub_msg]
727
728        selector = attr_spec.selector
729        value = search_attrs.lookup(selector)
730        if value not in sub_msg_spec.formats:
731            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
732
733        spec = sub_msg_spec.formats[value]
734        return spec, value
735
736    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
737        msg_format, _ = self._resolve_selector(attr_spec, search_attrs)
738        decoded = {}
739        offset = 0
740        if msg_format.fixed_header:
741            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
742            offset = self._struct_size(msg_format.fixed_header)
743        if msg_format.attr_set:
744            if msg_format.attr_set in self.attr_sets:
745                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
746                decoded.update(subdict)
747            else:
748                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
749        return decoded
750
751    def _decode(self, attrs, space, outer_attrs = None):
752        rsp = dict()
753        if space:
754            attr_space = self.attr_sets[space]
755            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
756
757        for attr in attrs:
758            try:
759                attr_spec = attr_space.attrs_by_val[attr.type]
760            except (KeyError, UnboundLocalError):
761                if not self.process_unknown:
762                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
763                attr_name = f"UnknownAttr({attr.type})"
764                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
765                continue
766
767            try:
768                if attr_spec["type"] == 'nest':
769                    subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
770                    decoded = subdict
771                elif attr_spec["type"] == 'string':
772                    decoded = attr.as_strz()
773                elif attr_spec["type"] == 'binary':
774                    decoded = self._decode_binary(attr, attr_spec)
775                elif attr_spec["type"] == 'flag':
776                    decoded = True
777                elif attr_spec.is_auto_scalar:
778                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
779                    if 'enum' in attr_spec:
780                        decoded = self._decode_enum(decoded, attr_spec)
781                elif attr_spec["type"] in NlAttr.type_formats:
782                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
783                    if 'enum' in attr_spec:
784                        decoded = self._decode_enum(decoded, attr_spec)
785                    elif attr_spec.display_hint:
786                        decoded = self._formatted_string(decoded, attr_spec.display_hint)
787                elif attr_spec["type"] == 'indexed-array':
788                    decoded = self._decode_array_attr(attr, attr_spec)
789                elif attr_spec["type"] == 'bitfield32':
790                    value, selector = struct.unpack("II", attr.raw)
791                    if 'enum' in attr_spec:
792                        value = self._decode_enum(value, attr_spec)
793                        selector = self._decode_enum(selector, attr_spec)
794                    decoded = {"value": value, "selector": selector}
795                elif attr_spec["type"] == 'sub-message':
796                    decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
797                elif attr_spec["type"] == 'nest-type-value':
798                    decoded = self._decode_nest_type_value(attr, attr_spec)
799                else:
800                    if not self.process_unknown:
801                        raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
802                    decoded = self._decode_unknown(attr)
803
804                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
805            except:
806                print(f"Error decoding '{attr_spec.name}' from '{space}'")
807                raise
808
809        return rsp
810
811    def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs):
812        for attr in attrs:
813            try:
814                attr_spec = attr_set.attrs_by_val[attr.type]
815            except KeyError:
816                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
817            if offset > target:
818                break
819            if offset == target:
820                return '.' + attr_spec.name
821
822            if offset + attr.full_len <= target:
823                offset += attr.full_len
824                continue
825
826            pathname = attr_spec.name
827            if attr_spec['type'] == 'nest':
828                sub_attrs = self.attr_sets[attr_spec['nested-attributes']]
829                search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name']))
830            elif attr_spec['type'] == 'sub-message':
831                msg_format, value = self._resolve_selector(attr_spec, search_attrs)
832                if msg_format is None:
833                    raise Exception(f"Can't resolve sub-message of {attr_spec['name']} for extack")
834                sub_attrs = self.attr_sets[msg_format.attr_set]
835                pathname += f"({value})"
836            else:
837                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
838            offset += 4
839            subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs,
840                                               offset, target, search_attrs)
841            if subpath is None:
842                return None
843            return '.' + pathname + subpath
844
845        return None
846
847    def _decode_extack(self, request, op, extack, vals):
848        if 'bad-attr-offs' not in extack:
849            return
850
851        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
852        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
853        search_attrs = SpaceAttrs(op.attr_set, vals)
854        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
855                                        extack['bad-attr-offs'], search_attrs)
856        if path:
857            del extack['bad-attr-offs']
858            extack['bad-attr'] = path
859
860    def _struct_size(self, name):
861        if name:
862            members = self.consts[name].members
863            size = 0
864            for m in members:
865                if m.type in ['pad', 'binary']:
866                    if m.struct:
867                        size += self._struct_size(m.struct)
868                    else:
869                        size += m.len
870                else:
871                    format = NlAttr.get_format(m.type, m.byte_order)
872                    size += format.size
873            return size
874        else:
875            return 0
876
877    def _decode_struct(self, data, name):
878        members = self.consts[name].members
879        attrs = dict()
880        offset = 0
881        for m in members:
882            value = None
883            if m.type == 'pad':
884                offset += m.len
885            elif m.type == 'binary':
886                if m.struct:
887                    len = self._struct_size(m.struct)
888                    value = self._decode_struct(data[offset : offset + len],
889                                                m.struct)
890                    offset += len
891                else:
892                    value = data[offset : offset + m.len]
893                    offset += m.len
894            else:
895                format = NlAttr.get_format(m.type, m.byte_order)
896                [ value ] = format.unpack_from(data, offset)
897                offset += format.size
898            if value is not None:
899                if m.enum:
900                    value = self._decode_enum(value, m)
901                elif m.display_hint:
902                    value = self._formatted_string(value, m.display_hint)
903                attrs[m.name] = value
904        return attrs
905
906    def _encode_struct(self, name, vals):
907        members = self.consts[name].members
908        attr_payload = b''
909        for m in members:
910            value = vals.pop(m.name) if m.name in vals else None
911            if m.type == 'pad':
912                attr_payload += bytearray(m.len)
913            elif m.type == 'binary':
914                if m.struct:
915                    if value is None:
916                        value = dict()
917                    attr_payload += self._encode_struct(m.struct, value)
918                else:
919                    if value is None:
920                        attr_payload += bytearray(m.len)
921                    else:
922                        attr_payload += bytes.fromhex(value)
923            else:
924                if value is None:
925                    value = 0
926                format = NlAttr.get_format(m.type, m.byte_order)
927                attr_payload += format.pack(value)
928        return attr_payload
929
930    def _formatted_string(self, raw, display_hint):
931        if display_hint == 'mac':
932            formatted = ':'.join('%02x' % b for b in raw)
933        elif display_hint == 'hex':
934            if isinstance(raw, int):
935                formatted = hex(raw)
936            else:
937                formatted = bytes.hex(raw, ' ')
938        elif display_hint in [ 'ipv4', 'ipv6' ]:
939            formatted = format(ipaddress.ip_address(raw))
940        elif display_hint == 'uuid':
941            formatted = str(uuid.UUID(bytes=raw))
942        else:
943            formatted = raw
944        return formatted
945
946    def _from_string(self, string, attr_spec):
947        if attr_spec.display_hint in ['ipv4', 'ipv6']:
948            ip = ipaddress.ip_address(string)
949            if attr_spec['type'] == 'binary':
950                raw = ip.packed
951            else:
952                raw = int(ip)
953        else:
954            raise Exception(f"Display hint '{attr_spec.display_hint}' not implemented"
955                            f" when parsing '{attr_spec['name']}'")
956        return raw
957
958    def handle_ntf(self, decoded):
959        msg = dict()
960        if self.include_raw:
961            msg['raw'] = decoded
962        op = self.rsp_by_value[decoded.cmd()]
963        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
964        if op.fixed_header:
965            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
966
967        msg['name'] = op['name']
968        msg['msg'] = attrs
969        self.async_msg_queue.put(msg)
970
971    def check_ntf(self):
972        while True:
973            try:
974                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
975            except BlockingIOError:
976                return
977
978            nms = NlMsgs(reply)
979            self._recv_dbg_print(reply, nms)
980            for nl_msg in nms:
981                if nl_msg.error:
982                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
983                    print(nl_msg)
984                    continue
985                if nl_msg.done:
986                    print("Netlink done while checking for ntf!?")
987                    continue
988
989                decoded = self.nlproto.decode(self, nl_msg, None)
990                if decoded.cmd() not in self.async_msg_ids:
991                    print("Unexpected msg id while checking for ntf", decoded)
992                    continue
993
994                self.handle_ntf(decoded)
995
996    def poll_ntf(self, duration=None):
997        start_time = time.time()
998        selector = selectors.DefaultSelector()
999        selector.register(self.sock, selectors.EVENT_READ)
1000
1001        while True:
1002            try:
1003                yield self.async_msg_queue.get_nowait()
1004            except queue.Empty:
1005                if duration is not None:
1006                    timeout = start_time + duration - time.time()
1007                    if timeout <= 0:
1008                        return
1009                else:
1010                    timeout = None
1011                events = selector.select(timeout)
1012                if events:
1013                    self.check_ntf()
1014
1015    def operation_do_attributes(self, name):
1016      """
1017      For a given operation name, find and return a supported
1018      set of attributes (as a dict).
1019      """
1020      op = self.find_operation(name)
1021      if not op:
1022        return None
1023
1024      return op['do']['request']['attributes'].copy()
1025
1026    def _encode_message(self, op, vals, flags, req_seq):
1027        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
1028        for flag in flags or []:
1029            nl_flags |= flag
1030
1031        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
1032        if op.fixed_header:
1033            msg += self._encode_struct(op.fixed_header, vals)
1034        search_attrs = SpaceAttrs(op.attr_set, vals)
1035        for name, value in vals.items():
1036            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
1037        msg = _genl_msg_finalize(msg)
1038        return msg
1039
    def _ops(self, ops):
        """
        Send a batch of requests and collect the response to each of them.

        ops is an iterable of (method, vals, flags) tuples, method being
        a key of self.ops. All the requests are encoded with consecutive
        sequence numbers and sent with a single send() call.

        Returns a list with one entry per completed request, in the order
        in which the requests completed. An entry is a list of decoded
        messages for dump requests, otherwise None when no reply message was
        received, the decoded message itself when there was exactly one,
        and a list of decoded messages when there were more.

        Raises NlError when a message with an error is received.
        """
        reqs_by_seq = {}
        # Start from a random sequence number, the following requests get
        # consecutive values
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            # Keep the request around, it is needed to decode extack offsets
            reqs_by_seq[req_seq] = (op, vals, msg, flags)
            payload += msg
            req_seq += 1

        self.sock.send(payload, 0)

        done = False
        rsp = []
        # Messages decoded so far for the request which is being answered
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    if nl_msg.extack:
                        nl_msg.annotate_extack(op.attr_set)
                        self._decode_extack(req_msg, op, nl_msg.extack, vals)
                else:
                    # Not a response to any of our pending requests
                    op = None
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    # NOTE(review): a done message with a sequence number
                    # which is not in reqs_by_seq raises KeyError here
                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    # NOTE(review): messages following this one in the same
                    # recv() buffer are not processed
                    break

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                # (op is None only when the sequence number is unknown, in
                # which case the first condition is already true)
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp
1108
1109    def _op(self, method, vals, flags=None, dump=False):
1110        req_flags = flags or []
1111        if dump:
1112            req_flags.append(Netlink.NLM_F_DUMP)
1113
1114        ops = [(method, vals, req_flags)]
1115        return self._ops(ops)[0]
1116
1117    def do(self, method, vals, flags=None):
1118        return self._op(method, vals, flags)
1119
1120    def dump(self, method, vals):
1121        return self._op(method, vals, dump=True)
1122
1123    def do_multi(self, ops):
1124        return self._ops(ops)
1125