xref: /linux/tools/net/ynl/pyynl/lib/ynl.py (revision 569a5d63fd1b2b106a2134625f4ee7a7f9056996)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4from enum import Enum
5import functools
6import os
7import random
8import socket
9import struct
10from struct import Struct
11import sys
12import yaml
13import ipaddress
14import uuid
15import queue
16import selectors
17import time
18
19from .nlspec import SpecFamily
20
21#
22# Generic Netlink code which should really be in some library, but I can't quickly find one.
23#
24
25
class Netlink:
    """Numeric constants mirrored from the netlink, genetlink and nlctrl UAPI."""

    # Socket option level and the options set on it
    SOL_NETLINK = 270
    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Control message types
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # Standard request flags
    NLM_F_REQUEST = 0x1
    NLM_F_ACK = 0x4

    # Modifiers for GET requests; a dump is ROOT | MATCH
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # Modifiers for NEW requests (they reuse the GET modifier bits)
    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    # Flags for ACK messages
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    # Flag bits carried in an attribute's type field, and their union
    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Generic netlink protocol number and the id of its control family
    NETLINK_GENERIC = 16
    GENL_ID_CTRL = 0x10

    # nlctrl command
    CTRL_CMD_GETFAMILY = 3

    # nlctrl family attributes
    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    # Attributes inside each CTRL_ATTR_MCAST_GROUPS entry
    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extended ACK attribute types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Attributes of the policy nest carried by NLMSGERR_ATTR_POLICY
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Policy attribute types reported in NL_POLICY_TYPE_ATTR_TYPE;
    # numbering starts at 1 ('flag').
    AttrType = Enum('AttrType', (
        'flag',
        'u8', 'u16', 'u32', 'u64',
        's8', 's16', 's32', 's64',
        'binary', 'string', 'nul-string',
        'nested', 'nested-array',
        'bitfield32', 'sint', 'uint',
    ), start=1)
102
class NlError(Exception):
    """Raised when the kernel answers a request with a netlink error.

    Attributes:
        nl_msg: the message carrying the error (an NlMsg).
        error: the positive errno value; the message stores it negated.
    """

    def __init__(self, nl_msg):
        super().__init__(nl_msg)
        self.nl_msg = nl_msg
        # Netlink reports failures as negative errno values.
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"
110
111
class ConfigError(Exception):
    """Raised for invalid YnlFamily settings, e.g. a too-small recv_size."""
114
115
class NlAttr:
    """A single netlink attribute (TLV) parsed from a byte buffer.

    Attributes:
        type: attribute type with the NLA_F_NESTED / NLA_F_NET_BYTEORDER
            bits masked off.
        is_nest: non-zero when NLA_F_NESTED was set.
        payload_len: the header's length field (it includes the header).
        full_len: payload_len rounded up to 4 bytes, i.e. the distance to
            the next attribute.
        raw: the payload bytes, without the 4-byte header.
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute starting at byte ``offset`` of ``raw``.

        Raises:
            Exception: the length field is shorter than the 4-byte header.
            struct.error: fewer than 4 bytes remain at ``offset``.
        """
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # A zero length would give full_len == 0, and callers that step by
        # full_len would never advance. Lengths 1-3 cannot even hold the
        # header, so they are malformed as well.
        if self._len < 4:
            raise Exception(f"Malformed netlink attribute at offset {offset}: "
                            f"length {self._len} is shorter than the 4-byte header")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for scalar type ``attr_type`` ('u8' .. 's64').

        A ``byte_order`` of "big-endian" selects big-endian, any other truthy
        value little-endian, and None/empty the native byte order.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the payload as a single ``attr_type`` scalar."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a 'sint'/'uint' payload, whose width (32 or 64 bits) is
        given by the payload length.
        """
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        # 's' or 'u' prefix plus the width in bits, e.g. 'u32' or 's64'
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII and drop its final (NUL) byte."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as a packed array of native-order ``type``."""
        fmt = self.get_format(type)
        return [x[0] for x in fmt.iter_unpack(self.raw)]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
168
169
class NlAttrs:
    """A list of the consecutive NlAttr objects found in a buffer.

    Parsing starts at ``offset`` and runs to the end of ``msg``, stepping by
    each attribute's 4-byte-aligned length.
    """

    def __init__(self, msg, offset=0):
        self.attrs = []
        pos = offset
        while pos < len(msg):
            nla = NlAttr(msg, pos)
            pos += nla.full_len
            self.attrs.append(nla)

    def __iter__(self):
        for nla in self.attrs:
            yield nla

    def __repr__(self):
        # one attribute per line
        return '\n'.join(repr(nla) for nla in self.attrs)
189
190
class NlMsg:
    """One netlink message parsed out of a receive buffer.

    Decodes the 16-byte nlmsghdr and exposes the payload as ``raw``. For
    NLMSG_ERROR and NLMSG_DONE messages it also sets ``error`` and ``done``,
    and decodes extended-ACK TLVs into the ``extack`` dict when the kernel
    flagged their presence with NLM_F_ACK_TLVS.
    """

    def __init__(self, msg, offset, attr_space=None):
        payload_start = offset + 16
        self.hdr = msg[offset:payload_start]
        (self.nl_len, self.nl_type, self.nl_flags,
         self.nl_seq, self.nl_portid) = struct.unpack("IHHII", self.hdr)
        self.raw = msg[payload_start:offset + self.nl_len]

        self.error = 0
        self.done = 0

        if self.nl_type == Netlink.NLMSG_ERROR:
            # 4-byte error code + echoed 16-byte request header (assumes the
            # socket set NETLINK_CAP_ACK so the request payload is not echoed)
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            # 4-byte status word only
            extack_off = 4
        else:
            extack_off = None

        if extack_off is not None:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1

        self.extack = None
        if extack_off is not None and self.nl_flags & Netlink.NLM_F_ACK_TLVS:
            self.extack = self._parse_extack(self.raw[extack_off:], attr_space)

    def _parse_extack(self, raw, attr_space):
        """Decode the extended-ACK TLVs in ``raw`` into a dict.

        When ``attr_space`` is given, a top-level missing attribute type is
        translated to its spec name (plus its doc, if any).
        """
        u32_keys = {
            Netlink.NLMSGERR_ATTR_MISS_TYPE: 'miss-type',
            Netlink.NLMSGERR_ATTR_MISS_NEST: 'miss-nest',
            Netlink.NLMSGERR_ATTR_OFFS: 'bad-attr-offs',
        }
        extack = {}
        for tlv in NlAttrs(raw):
            if tlv.type == Netlink.NLMSGERR_ATTR_MSG:
                extack['msg'] = tlv.as_strz()
            elif tlv.type in u32_keys:
                extack[u32_keys[tlv.type]] = tlv.as_scalar('u32')
            elif tlv.type == Netlink.NLMSGERR_ATTR_POLICY:
                extack['policy'] = self._decode_policy(tlv.raw)
            else:
                extack.setdefault('unknown', []).append(tlv)

        # Nested attribute spaces can't be resolved yet, so only translate
        # the missing type when no missing nest was reported.
        if attr_space and 'miss-type' in extack and 'miss-nest' not in extack:
            miss_type = extack['miss-type']
            if miss_type in attr_space.attrs_by_val:
                spec = attr_space.attrs_by_val[miss_type]
                extack['miss-type'] = spec['name']
                if 'doc' in spec:
                    extack['miss-type-doc'] = spec['doc']
        return extack

    def _decode_policy(self, raw):
        """Translate an NLMSGERR_ATTR_POLICY nest into a dict."""
        # attribute type -> (result key, scalar format)
        scalar_fields = {
            Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S: ('min-value', 's64'),
            Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S: ('max-value', 's64'),
            Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U: ('min-value', 'u64'),
            Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U: ('max-value', 'u64'),
            Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH: ('min-length', 'u32'),
            Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH: ('max-length', 'u32'),
            Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK: ('bitfield32-mask', 'u32'),
            Netlink.NL_POLICY_TYPE_ATTR_MASK: ('mask', 'u64'),
        }
        policy = {}
        for nla in NlAttrs(raw):
            if nla.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                policy['type'] = Netlink.AttrType(nla.as_scalar('u32')).name
            elif nla.type in scalar_fields:
                key, fmt = scalar_fields[nla.type]
                policy[key] = nla.as_scalar(fmt)
        return policy

    def cmd(self):
        return self.nl_type

    def __repr__(self):
        lines = [f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"]
        if self.error:
            lines.append('\terror: ' + str(self.error))
        if self.extack:
            lines.append('\textack: ' + repr(self.extack))
        return '\n'.join(lines)
277
278
class NlMsgs:
    """Split a receive buffer into its consecutive NlMsg objects.

    Args:
        data: bytes returned by recv(); may hold several netlink messages.
        attr_space: optional attribute set forwarded to every NlMsg so it
            can name attributes reported in extended ACKs.

    Raises:
        Exception: a header claims a length shorter than the 16-byte
            nlmsghdr (a zero length would otherwise stall the loop forever).
    """

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            # nl_len includes the header itself; anything smaller is
            # malformed, and 0 would never advance the offset.
            if msg.nl_len < 16:
                raise Exception(f"Malformed netlink message at offset {offset}: "
                                f"length {msg.nl_len} is shorter than the 16-byte header")
            offset += msg.nl_len
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
291
292
# Cache of the generic netlink families reported by the kernel, keyed by
# family name. Despite the name, each value is the whole family dict: 'id',
# 'name', and optionally 'maxattr' and 'mcast' (group name -> group id).
# None until _genl_load_families() runs, which GenlProtocol triggers on
# first use.
genl_family_name_to_id = None
294
295
296def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
297    # we prepend length in _genl_msg_finalize()
298    if seq is None:
299        seq = random.randint(1, 1024)
300    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
301    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
302    return nlmsg + genlmsg
303
304
305def _genl_msg_finalize(msg):
306    return struct.pack("I", len(msg) + 4) + msg
307
308
def _genl_parse_mcast_groups(nest):
    """Return {group name: group id} from a CTRL_ATTR_MCAST_GROUPS attribute."""
    groups = dict()
    for entry in NlAttrs(nest.raw):
        grp_name = None
        grp_id = None
        for grp_attr in NlAttrs(entry.raw):
            if grp_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                grp_name = grp_attr.as_strz()
            elif grp_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                grp_id = grp_attr.as_scalar('u32')
        if grp_name and grp_id is not None:
            groups[grp_name] = grp_id
    return groups


def _genl_parse_family(genl_msg):
    """Collect id/name/maxattr/mcast of one CTRL_CMD_GETFAMILY reply."""
    family = dict()
    for attr in NlAttrs(genl_msg.raw):
        if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
            family['id'] = attr.as_scalar('u16')
        elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
            family['name'] = attr.as_strz()
        elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
            family['maxattr'] = attr.as_scalar('u32')
        elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
            family['mcast'] = _genl_parse_mcast_groups(attr)
    return family


def _genl_load_families():
    """Fill genl_family_name_to_id by dumping all families from nlctrl.

    Sends CTRL_CMD_GETFAMILY as a dump and records each reply that carries
    both a family name and an id. On a netlink error, the error is printed
    and loading stops, keeping whatever families were collected so far.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        dump_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP
        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL, dump_flags, Netlink.CTRL_CMD_GETFAMILY, 1))
        sock.send(request, 0)

        genl_family_name_to_id = dict()

        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                family = _genl_parse_family(GenlMsg(nl_msg))
                if 'name' in family and 'id' in family:
                    genl_family_name_to_id[family['name']] = family
356
357
class GenlMsg:
    """A generic netlink message: an NlMsg plus its 4-byte genlmsghdr.

    Attributes:
        nl: the underlying NlMsg.
        genl_cmd, genl_version: fields of the genetlink header.
        raw: the payload following the genetlink header.
        raw_attrs: attached externally (NetlinkProtocol.decode); absent
            until then.
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs only exists once decode() has run; repr() must not raise
        # before that, e.g. when the message is printed for debugging.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg
373
374
class NetlinkProtocol:
    """Message framing for a classic (non-generic) netlink protocol.

    GenlProtocol overrides the hooks below to add the genetlink header.
    """

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack an nlmsghdr without its leading 4-byte length field.

        A random sequence number in [1, 1024] is used when ``seq`` is None;
        the length is presumably prepended by the caller once the message is
        complete.
        """
        sequence = random.randint(1, 1024) if seq is None else seq
        return struct.pack("HHII", nl_type, nl_flags, sequence, 0)

    def message(self, flags, command, version, seq=None):
        """Build a request header; the command is the netlink message type
        and ``version`` is not used by plain netlink.
        """
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        """Wrap a received NlMsg; plain netlink uses it unchanged."""
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap ``nl_msg`` and attach its parsed attributes as ``raw_attrs``.

        When ``op`` is None, it is looked up by the message's command in
        ``ynl.rsp_by_value``. Attribute parsing starts after the op's fixed
        header, whose size comes from ``ynl._struct_size``.
        """
        wrapped = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[wrapped.cmd()]
        attrs_start = ynl._struct_size(op.fixed_header)
        wrapped.raw_attrs = NlAttrs(wrapped.raw, attrs_start)
        return wrapped

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the id of multicast group ``mcast_name`` from the spec."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        """Size in bytes of the headers preceding the attributes."""
        return 16  # struct nlmsghdr
407
408
class GenlProtocol(NetlinkProtocol):
    """Message framing for a generic netlink family.

    The family is resolved by name from the kernel's family list, which is
    loaded on first use. Construction raises KeyError if the kernel does not
    report the family.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        if genl_family_name_to_id is None:
            _genl_load_families()

        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = self.genl_family['id']

    def message(self, flags, command, version, seq=None):
        """Build nlmsghdr (typed with the family id) + genlmsghdr."""
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the kernel-assigned id of multicast group ``mcast_name``.

        Raises:
            Exception: the family has no such group.
        """
        # The family dict only has 'mcast' when the kernel reported groups;
        # report a missing group the same way either way.
        groups = self.genl_family.get('mcast', {})
        if mcast_name not in groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return groups[mcast_name]

    def msghdr_size(self):
        """nlmsghdr plus the 4-byte genlmsghdr."""
        return super().msghdr_size() + 4
435
436
class SpaceAttrs:
    """A chain of (attribute-set spec, values) scopes for name lookups.

    Used to resolve sub-message selectors: a name is searched from the
    innermost scope outward.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        innermost = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [innermost, *(outer.scopes if outer else [])]

    def lookup(self, name):
        """Return the value of ``name`` from the nearest scope defining it.

        Raises:
            Exception: the defining scope has no value for ``name``, or no
                scope defines it at all.
        """
        scope = next((s for s in self.scopes if name in s.spec), None)
        if scope is None:
            raise Exception(f"Attribute '{name}' not defined in any attribute-set")
        if name not in scope.values:
            spec_name = scope.spec.yaml['name']
            raise Exception(
                f"No value for '{name}' in attribute space '{spec_name}'")
        return scope.values[name]
454
455
456#
457# YNL implementation details.
458#
459
460
461class YnlFamily(SpecFamily):
462    def __init__(self, def_path, schema=None, process_unknown=False,
463                 recv_size=0):
464        super().__init__(def_path, schema)
465
466        self.include_raw = False
467        self.process_unknown = process_unknown
468
469        try:
470            if self.proto == "netlink-raw":
471                self.nlproto = NetlinkProtocol(self.yaml['name'],
472                                               self.yaml['protonum'])
473            else:
474                self.nlproto = GenlProtocol(self.yaml['name'])
475        except KeyError:
476            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
477
478        self._recv_dbg = False
479        # Note that netlink will use conservative (min) message size for
480        # the first dump recv() on the socket, our setting will only matter
481        # from the second recv() on.
482        self._recv_size = recv_size if recv_size else 131072
483        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
484        # for a message, so smaller receive sizes will lead to truncation.
485        # Note that the min size for other families may be larger than 4k!
486        if self._recv_size < 4000:
487            raise ConfigError()
488
489        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
490        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
491        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
492        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
493
494        self.async_msg_ids = set()
495        self.async_msg_queue = queue.Queue()
496
497        for msg in self.msgs.values():
498            if msg.is_async:
499                self.async_msg_ids.add(msg.rsp_value)
500
501        for op_name, op in self.ops.items():
502            bound_f = functools.partial(self._op, op_name)
503            setattr(self, op.ident_name, bound_f)
504
505
506    def ntf_subscribe(self, mcast_name):
507        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
508        self.sock.bind((0, 0))
509        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
510                             mcast_id)
511
512    def set_recv_dbg(self, enabled):
513        self._recv_dbg = enabled
514
515    def _recv_dbg_print(self, reply, nl_msgs):
516        if not self._recv_dbg:
517            return
518        print("Recv: read", len(reply), "bytes,",
519              len(nl_msgs.msgs), "messages", file=sys.stderr)
520        for nl_msg in nl_msgs:
521            print("  ", nl_msg, file=sys.stderr)
522
523    def _encode_enum(self, attr_spec, value):
524        enum = self.consts[attr_spec['enum']]
525        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
526            scalar = 0
527            if isinstance(value, str):
528                value = [value]
529            for single_value in value:
530                scalar += enum.entries[single_value].user_value(as_flags = True)
531            return scalar
532        else:
533            return enum.entries[value].user_value()
534
535    def _get_scalar(self, attr_spec, value):
536        try:
537            return int(value)
538        except (ValueError, TypeError) as e:
539            if 'enum' not in attr_spec:
540                raise e
541        return self._encode_enum(attr_spec, value)
542
543    def _add_attr(self, space, name, value, search_attrs):
544        try:
545            attr = self.attr_sets[space][name]
546        except KeyError:
547            raise Exception(f"Space '{space}' has no attribute '{name}'")
548        nl_type = attr.value
549
550        if attr.is_multi and isinstance(value, list):
551            attr_payload = b''
552            for subvalue in value:
553                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
554            return attr_payload
555
556        if attr["type"] == 'nest':
557            nl_type |= Netlink.NLA_F_NESTED
558            attr_payload = b''
559            sub_space = attr['nested-attributes']
560            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
561            for subname, subvalue in value.items():
562                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
563        elif attr["type"] == 'flag':
564            if not value:
565                # If value is absent or false then skip attribute creation.
566                return b''
567            attr_payload = b''
568        elif attr["type"] == 'string':
569            attr_payload = str(value).encode('ascii') + b'\x00'
570        elif attr["type"] == 'binary':
571            if isinstance(value, bytes):
572                attr_payload = value
573            elif isinstance(value, str):
574                attr_payload = bytes.fromhex(value)
575            elif isinstance(value, dict) and attr.struct_name:
576                attr_payload = self._encode_struct(attr.struct_name, value)
577            else:
578                raise Exception(f'Unknown type for binary attribute, value: {value}')
579        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
580            scalar = self._get_scalar(attr, value)
581            if attr.is_auto_scalar:
582                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
583            else:
584                attr_type = attr["type"]
585            format = NlAttr.get_format(attr_type, attr.byte_order)
586            attr_payload = format.pack(scalar)
587        elif attr['type'] in "bitfield32":
588            scalar_value = self._get_scalar(attr, value["value"])
589            scalar_selector = self._get_scalar(attr, value["selector"])
590            attr_payload = struct.pack("II", scalar_value, scalar_selector)
591        elif attr['type'] == 'sub-message':
592            msg_format = self._resolve_selector(attr, search_attrs)
593            attr_payload = b''
594            if msg_format.fixed_header:
595                attr_payload += self._encode_struct(msg_format.fixed_header, value)
596            if msg_format.attr_set:
597                if msg_format.attr_set in self.attr_sets:
598                    nl_type |= Netlink.NLA_F_NESTED
599                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
600                    for subname, subvalue in value.items():
601                        attr_payload += self._add_attr(msg_format.attr_set,
602                                                       subname, subvalue, sub_attrs)
603                else:
604                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
605        else:
606            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
607
608        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
609        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
610
611    def _decode_enum(self, raw, attr_spec):
612        enum = self.consts[attr_spec['enum']]
613        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
614            i = 0
615            value = set()
616            while raw:
617                if raw & 1:
618                    value.add(enum.entries_by_val[i].name)
619                raw >>= 1
620                i += 1
621        else:
622            value = enum.entries_by_val[raw].name
623        return value
624
625    def _decode_binary(self, attr, attr_spec):
626        if attr_spec.struct_name:
627            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
628        elif attr_spec.sub_type:
629            decoded = attr.as_c_array(attr_spec.sub_type)
630            if 'enum' in attr_spec:
631                decoded = [ self._decode_enum(x, attr_spec) for x in decoded ]
632            elif attr_spec.display_hint:
633                decoded = [ self._formatted_string(x, attr_spec.display_hint)
634                            for x in decoded ]
635        else:
636            decoded = attr.as_bin()
637            if attr_spec.display_hint:
638                decoded = self._formatted_string(decoded, attr_spec.display_hint)
639        return decoded
640
641    def _decode_array_attr(self, attr, attr_spec):
642        decoded = []
643        offset = 0
644        while offset < len(attr.raw):
645            item = NlAttr(attr.raw, offset)
646            offset += item.full_len
647
648            if attr_spec["sub-type"] == 'nest':
649                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
650                decoded.append({ item.type: subattrs })
651            elif attr_spec["sub-type"] == 'binary':
652                subattr = item.as_bin()
653                if attr_spec.display_hint:
654                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
655                decoded.append(subattr)
656            elif attr_spec["sub-type"] in NlAttr.type_formats:
657                subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
658                if 'enum' in attr_spec:
659                    subattr = self._decode_enum(subattr, attr_spec)
660                elif attr_spec.display_hint:
661                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
662                decoded.append(subattr)
663            else:
664                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
665        return decoded
666
667    def _decode_nest_type_value(self, attr, attr_spec):
668        decoded = {}
669        value = attr
670        for name in attr_spec['type-value']:
671            value = NlAttr(value.raw, 0)
672            decoded[name] = value.type
673        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
674        decoded.update(subattrs)
675        return decoded
676
677    def _decode_unknown(self, attr):
678        if attr.is_nest:
679            return self._decode(NlAttrs(attr.raw), None)
680        else:
681            return attr.as_bin()
682
683    def _rsp_add(self, rsp, name, is_multi, decoded):
684        if is_multi == None:
685            if name in rsp and type(rsp[name]) is not list:
686                rsp[name] = [rsp[name]]
687                is_multi = True
688            else:
689                is_multi = False
690
691        if not is_multi:
692            rsp[name] = decoded
693        elif name in rsp:
694            rsp[name].append(decoded)
695        else:
696            rsp[name] = [decoded]
697
698    def _resolve_selector(self, attr_spec, search_attrs):
699        sub_msg = attr_spec.sub_message
700        if sub_msg not in self.sub_msgs:
701            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
702        sub_msg_spec = self.sub_msgs[sub_msg]
703
704        selector = attr_spec.selector
705        value = search_attrs.lookup(selector)
706        if value not in sub_msg_spec.formats:
707            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
708
709        spec = sub_msg_spec.formats[value]
710        return spec
711
712    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
713        msg_format = self._resolve_selector(attr_spec, search_attrs)
714        decoded = {}
715        offset = 0
716        if msg_format.fixed_header:
717            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header));
718            offset = self._struct_size(msg_format.fixed_header)
719        if msg_format.attr_set:
720            if msg_format.attr_set in self.attr_sets:
721                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
722                decoded.update(subdict)
723            else:
724                raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")
725        return decoded
726
727    def _decode(self, attrs, space, outer_attrs = None):
728        rsp = dict()
729        if space:
730            attr_space = self.attr_sets[space]
731            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
732
733        for attr in attrs:
734            try:
735                attr_spec = attr_space.attrs_by_val[attr.type]
736            except (KeyError, UnboundLocalError):
737                if not self.process_unknown:
738                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
739                attr_name = f"UnknownAttr({attr.type})"
740                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
741                continue
742
743            try:
744                if attr_spec["type"] == 'nest':
745                    subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
746                    decoded = subdict
747                elif attr_spec["type"] == 'string':
748                    decoded = attr.as_strz()
749                elif attr_spec["type"] == 'binary':
750                    decoded = self._decode_binary(attr, attr_spec)
751                elif attr_spec["type"] == 'flag':
752                    decoded = True
753                elif attr_spec.is_auto_scalar:
754                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
755                elif attr_spec["type"] in NlAttr.type_formats:
756                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
757                    if 'enum' in attr_spec:
758                        decoded = self._decode_enum(decoded, attr_spec)
759                    elif attr_spec.display_hint:
760                        decoded = self._formatted_string(decoded, attr_spec.display_hint)
761                elif attr_spec["type"] == 'indexed-array':
762                    decoded = self._decode_array_attr(attr, attr_spec)
763                elif attr_spec["type"] == 'bitfield32':
764                    value, selector = struct.unpack("II", attr.raw)
765                    if 'enum' in attr_spec:
766                        value = self._decode_enum(value, attr_spec)
767                        selector = self._decode_enum(selector, attr_spec)
768                    decoded = {"value": value, "selector": selector}
769                elif attr_spec["type"] == 'sub-message':
770                    decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
771                elif attr_spec["type"] == 'nest-type-value':
772                    decoded = self._decode_nest_type_value(attr, attr_spec)
773                else:
774                    if not self.process_unknown:
775                        raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
776                    decoded = self._decode_unknown(attr)
777
778                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
779            except:
780                print(f"Error decoding '{attr_spec.name}' from '{space}'")
781                raise
782
783        return rsp
784
785    def _decode_extack_path(self, attrs, attr_set, offset, target):
786        for attr in attrs:
787            try:
788                attr_spec = attr_set.attrs_by_val[attr.type]
789            except KeyError:
790                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
791            if offset > target:
792                break
793            if offset == target:
794                return '.' + attr_spec.name
795
796            if offset + attr.full_len <= target:
797                offset += attr.full_len
798                continue
799            if attr_spec['type'] != 'nest':
800                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
801            offset += 4
802            subpath = self._decode_extack_path(NlAttrs(attr.raw),
803                                               self.attr_sets[attr_spec['nested-attributes']],
804                                               offset, target)
805            if subpath is None:
806                return None
807            return '.' + attr_spec.name + subpath
808
809        return None
810
811    def _decode_extack(self, request, op, extack):
812        if 'bad-attr-offs' not in extack:
813            return
814
815        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
816        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
817        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
818                                        extack['bad-attr-offs'])
819        if path:
820            del extack['bad-attr-offs']
821            extack['bad-attr'] = path
822
823    def _struct_size(self, name):
824        if name:
825            members = self.consts[name].members
826            size = 0
827            for m in members:
828                if m.type in ['pad', 'binary']:
829                    if m.struct:
830                        size += self._struct_size(m.struct)
831                    else:
832                        size += m.len
833                else:
834                    format = NlAttr.get_format(m.type, m.byte_order)
835                    size += format.size
836            return size
837        else:
838            return 0
839
840    def _decode_struct(self, data, name):
841        members = self.consts[name].members
842        attrs = dict()
843        offset = 0
844        for m in members:
845            value = None
846            if m.type == 'pad':
847                offset += m.len
848            elif m.type == 'binary':
849                if m.struct:
850                    len = self._struct_size(m.struct)
851                    value = self._decode_struct(data[offset : offset + len],
852                                                m.struct)
853                    offset += len
854                else:
855                    value = data[offset : offset + m.len]
856                    offset += m.len
857            else:
858                format = NlAttr.get_format(m.type, m.byte_order)
859                [ value ] = format.unpack_from(data, offset)
860                offset += format.size
861            if value is not None:
862                if m.enum:
863                    value = self._decode_enum(value, m)
864                elif m.display_hint:
865                    value = self._formatted_string(value, m.display_hint)
866                attrs[m.name] = value
867        return attrs
868
869    def _encode_struct(self, name, vals):
870        members = self.consts[name].members
871        attr_payload = b''
872        for m in members:
873            value = vals.pop(m.name) if m.name in vals else None
874            if m.type == 'pad':
875                attr_payload += bytearray(m.len)
876            elif m.type == 'binary':
877                if m.struct:
878                    if value is None:
879                        value = dict()
880                    attr_payload += self._encode_struct(m.struct, value)
881                else:
882                    if value is None:
883                        attr_payload += bytearray(m.len)
884                    else:
885                        attr_payload += bytes.fromhex(value)
886            else:
887                if value is None:
888                    value = 0
889                format = NlAttr.get_format(m.type, m.byte_order)
890                attr_payload += format.pack(value)
891        return attr_payload
892
893    def _formatted_string(self, raw, display_hint):
894        if display_hint == 'mac':
895            formatted = ':'.join('%02x' % b for b in raw)
896        elif display_hint == 'hex':
897            if isinstance(raw, int):
898                formatted = hex(raw)
899            else:
900                formatted = bytes.hex(raw, ' ')
901        elif display_hint in [ 'ipv4', 'ipv6' ]:
902            formatted = format(ipaddress.ip_address(raw))
903        elif display_hint == 'uuid':
904            formatted = str(uuid.UUID(bytes=raw))
905        else:
906            formatted = raw
907        return formatted
908
909    def handle_ntf(self, decoded):
910        msg = dict()
911        if self.include_raw:
912            msg['raw'] = decoded
913        op = self.rsp_by_value[decoded.cmd()]
914        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
915        if op.fixed_header:
916            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
917
918        msg['name'] = op['name']
919        msg['msg'] = attrs
920        self.async_msg_queue.put(msg)
921
922    def check_ntf(self):
923        while True:
924            try:
925                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
926            except BlockingIOError:
927                return
928
929            nms = NlMsgs(reply)
930            self._recv_dbg_print(reply, nms)
931            for nl_msg in nms:
932                if nl_msg.error:
933                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
934                    print(nl_msg)
935                    continue
936                if nl_msg.done:
937                    print("Netlink done while checking for ntf!?")
938                    continue
939
940                decoded = self.nlproto.decode(self, nl_msg, None)
941                if decoded.cmd() not in self.async_msg_ids:
942                    print("Unexpected msg id while checking for ntf", decoded)
943                    continue
944
945                self.handle_ntf(decoded)
946
947    def poll_ntf(self, duration=None):
948        start_time = time.time()
949        selector = selectors.DefaultSelector()
950        selector.register(self.sock, selectors.EVENT_READ)
951
952        while True:
953            try:
954                yield self.async_msg_queue.get_nowait()
955            except queue.Empty:
956                if duration is not None:
957                    timeout = start_time + duration - time.time()
958                    if timeout <= 0:
959                        return
960                else:
961                    timeout = None
962                events = selector.select(timeout)
963                if events:
964                    self.check_ntf()
965
966    def operation_do_attributes(self, name):
967      """
968      For a given operation name, find and return a supported
969      set of attributes (as a dict).
970      """
971      op = self.find_operation(name)
972      if not op:
973        return None
974
975      return op['do']['request']['attributes'].copy()
976
977    def _encode_message(self, op, vals, flags, req_seq):
978        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
979        for flag in flags or []:
980            nl_flags |= flag
981
982        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
983        if op.fixed_header:
984            msg += self._encode_struct(op.fixed_header, vals)
985        search_attrs = SpaceAttrs(op.attr_set, vals)
986        for name, value in vals.items():
987            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
988        msg = _genl_msg_finalize(msg)
989        return msg
990
991    def _ops(self, ops):
992        reqs_by_seq = {}
993        req_seq = random.randint(1024, 65535)
994        payload = b''
995        for (method, vals, flags) in ops:
996            op = self.ops[method]
997            msg = self._encode_message(op, vals, flags, req_seq)
998            reqs_by_seq[req_seq] = (op, msg, flags)
999            payload += msg
1000            req_seq += 1
1001
1002        self.sock.send(payload, 0)
1003
1004        done = False
1005        rsp = []
1006        op_rsp = []
1007        while not done:
1008            reply = self.sock.recv(self._recv_size)
1009            nms = NlMsgs(reply, attr_space=op.attr_set)
1010            self._recv_dbg_print(reply, nms)
1011            for nl_msg in nms:
1012                if nl_msg.nl_seq in reqs_by_seq:
1013                    (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
1014                    if nl_msg.extack:
1015                        self._decode_extack(req_msg, op, nl_msg.extack)
1016                else:
1017                    op = None
1018                    req_flags = []
1019
1020                if nl_msg.error:
1021                    raise NlError(nl_msg)
1022                if nl_msg.done:
1023                    if nl_msg.extack:
1024                        print("Netlink warning:")
1025                        print(nl_msg)
1026
1027                    if Netlink.NLM_F_DUMP in req_flags:
1028                        rsp.append(op_rsp)
1029                    elif not op_rsp:
1030                        rsp.append(None)
1031                    elif len(op_rsp) == 1:
1032                        rsp.append(op_rsp[0])
1033                    else:
1034                        rsp.append(op_rsp)
1035                    op_rsp = []
1036
1037                    del reqs_by_seq[nl_msg.nl_seq]
1038                    done = len(reqs_by_seq) == 0
1039                    break
1040
1041                decoded = self.nlproto.decode(self, nl_msg, op)
1042
1043                # Check if this is a reply to our request
1044                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
1045                    if decoded.cmd() in self.async_msg_ids:
1046                        self.handle_ntf(decoded)
1047                        continue
1048                    else:
1049                        print('Unexpected message: ' + repr(decoded))
1050                        continue
1051
1052                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
1053                if op.fixed_header:
1054                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
1055                op_rsp.append(rsp_msg)
1056
1057        return rsp
1058
1059    def _op(self, method, vals, flags=None, dump=False):
1060        req_flags = flags or []
1061        if dump:
1062            req_flags.append(Netlink.NLM_F_DUMP)
1063
1064        ops = [(method, vals, req_flags)]
1065        return self._ops(ops)[0]
1066
1067    def do(self, method, vals, flags=None):
1068        return self._op(method, vals, flags)
1069
1070    def dump(self, method, vals):
1071        return self._op(method, vals, dump=True)
1072
1073    def do_multi(self, ops):
1074        return self._ops(ops)
1075