xref: /linux/tools/net/ynl/lib/ynl.py (revision a885a6b2d37eaaae08323583bdb1928c8a2935fc)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4from enum import Enum
5import functools
6import os
7import random
8import socket
9import struct
10from struct import Struct
11import sys
12import yaml
13import ipaddress
14import uuid
15import queue
16import time
17
18from .nlspec import SpecFamily
19
20#
21# Generic Netlink code which should really be in some library, but I can't quickly find one.
22#
23
24
class Netlink:
    """Namespace for the numeric netlink / generic netlink constants used below."""

    # Netlink socket level and socket options
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message types
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # Netlink message flags; note that the three groups after the first
    # one reuse the same bit values (0x100, 0x200, ...).
    NLM_F_REQUEST = 0x001
    NLM_F_ACK = 0x004

    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    # Flag bits carried in the attribute type field, and the mask covering them
    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16
    GENL_ID_CTRL = 0x10

    # nlctrl command and attribute IDs
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack attribute types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy attribute types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Names of the policy attribute types; the functional Enum API numbers
    # the members from 1 in the order listed.
    AttrType = Enum('AttrType', [
        'flag',
        'u8', 'u16', 'u32', 'u64',
        's8', 's16', 's32', 's64',
        'binary', 'string', 'nul-string',
        'nested', 'nested-array',
        'bitfield32',
        'sint', 'uint',
    ])
101
class NlError(Exception):
    """Exception wrapping a netlink message that carries an error code.

    Attributes:
        nl_msg: the message object the exception was built from.
        error: the negated ``nl_msg.error`` value, passed to os.strerror()
            when the exception is rendered as text.
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        reason = os.strerror(self.error)
        return f"Netlink error: {reason}\n{self.nl_msg}"
109
110
class ConfigError(Exception):
    """Exception type used to report an unusable configuration value."""
113
114
class NlAttr:
    """A single netlink attribute parsed out of a byte buffer.

    The 4-byte attribute header (two native-endian 16-bit fields: length,
    then type) is unpacked on construction.  The payload is exposed as
    ``raw`` and can be interpreted with the ``as_*`` helpers.
    """

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    # Pre-compiled struct formats per scalar type, one per byte order.
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute whose header starts at ``offset`` in ``raw``."""
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # Split the type field into the attribute ID and the nested flag.
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        # The length field counts the 4-byte header as well as the payload.
        self.payload_len = self._len
        # Distance to the next attribute: the length rounded up to 4 bytes.
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for ``attr_type`` in the requested byte order.

        A falsy ``byte_order`` selects the native format, "big-endian"
        selects big endian and any other truthy value little endian.
        Raises KeyError for a type missing from ``type_formats``.
        """
        format = cls.type_formats[attr_type]
        if byte_order:
            return format.big if byte_order == "big-endian" \
                else format.little
        return format.native

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the whole payload as one scalar of type ``attr_type``."""
        format = self.get_format(attr_type, byte_order)
        return format.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a variable-width scalar; the payload size picks 32 or 64 bit.

        The first character of ``attr_type`` supplies the signedness prefix
        of the concrete type.  Raises Exception when the payload is neither
        4 nor 8 bytes long.
        """
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar payload length must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        format = self.get_format(real_type, byte_order)
        return format.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII and drop its last character."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes unchanged."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as consecutive native-endian scalars of ``type``."""
        format = self.get_format(type)
        return [ x[0] for x in format.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
167
168
class NlAttrs:
    """The attributes found back-to-back in a byte buffer.

    Parsing starts at ``offset`` and each NlAttr's ``full_len`` gives the
    distance to the next one, until the end of ``msg`` is reached.
    """

    def __init__(self, msg, offset=0):
        self.attrs = []

        pos = offset
        end = len(msg)
        while pos < end:
            parsed = NlAttr(msg, pos)
            pos += parsed.full_len
            self.attrs.append(parsed)

    def __iter__(self):
        for parsed in self.attrs:
            yield parsed

    def __repr__(self):
        # One attribute per line
        return '\n'.join(repr(parsed) for parsed in self.attrs)
188
189
class NlMsg:
    """One netlink message: a 16-byte header followed by its payload.

    For error and done messages the error code is read from the start of
    the payload and, when the message is flagged as carrying extended ACK
    TLVs, those are decoded into the ``extack`` dict.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        (self.nl_len, self.nl_type, self.nl_flags,
         self.nl_seq, self.nl_portid) = struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        # Payload offset at which the extack attributes start, for the two
        # message types that carry an error code; None for everything else.
        extack_off = {
            Netlink.NLMSG_ERROR: 20,
            Netlink.NLMSG_DONE: 4,
        }.get(self.nl_type)
        if extack_off is not None:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = {}
            # Extack TLVs that are plain u32 values, with their dict keys
            u32_keys = {
                Netlink.NLMSGERR_ATTR_MISS_TYPE: 'miss-type',
                Netlink.NLMSGERR_ATTR_MISS_NEST: 'miss-nest',
                Netlink.NLMSGERR_ATTR_OFFS: 'bad-attr-offs',
            }
            for tlv in NlAttrs(self.raw[extack_off:]):
                if tlv.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = tlv.as_strz()
                elif tlv.type in u32_keys:
                    self.extack[u32_keys[tlv.type]] = tlv.as_scalar('u32')
                elif tlv.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(tlv.raw)
                else:
                    self.extack.setdefault('unknown', []).append(tlv)

            if attr_space:
                # We don't have the ability to parse nests yet, so only
                # translate a missing type reported at the top level.
                top_level_miss = ('miss-type' in self.extack and
                                  'miss-nest' not in self.extack)
                if top_level_miss:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        self.extack['miss-type'] = spec['name']
                        if 'doc' in spec:
                            self.extack['miss-type-doc'] = spec['doc']

    def _decode_policy(self, raw):
        """Decode the policy attributes in ``raw`` into a plain dict."""
        # Policy attributes that map straight to a scalar: (dict key, type)
        scalar_fields = {
            Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S: ('min-value', 's64'),
            Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S: ('max-value', 's64'),
            Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U: ('min-value', 'u64'),
            Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U: ('max-value', 'u64'),
            Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH: ('min-length', 'u32'),
            Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH: ('max-length', 'u32'),
            Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK: ('bitfield32-mask', 'u32'),
            Netlink.NL_POLICY_TYPE_ATTR_MASK: ('mask', 'u64'),
        }
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_id = attr.as_scalar('u32')
                policy['type'] = Netlink.AttrType(type_id).name
            elif attr.type in scalar_fields:
                key, scalar_type = scalar_fields[attr.type]
                policy[key] = attr.as_scalar(scalar_type)
        return policy

    def cmd(self):
        """Return the netlink message type."""
        return self.nl_type

    def __repr__(self):
        text = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            text += f'\n\terror: {self.error}'
        if self.extack:
            text += f'\n\textack: {self.extack!r}'
        return text
276
277
class NlMsgs:
    """All netlink messages contained in one received buffer.

    Messages are parsed one after another; each message's ``nl_len`` gives
    the distance to the next.  ``attr_space`` is passed through to NlMsg.
    """

    def __init__(self, data, attr_space=None):
        self.msgs = []

        pos = 0
        end = len(data)
        while pos < end:
            parsed = NlMsg(data, pos, attr_space=attr_space)
            pos += parsed.nl_len
            self.msgs.append(parsed)

    def __iter__(self):
        for parsed in self.msgs:
            yield parsed
290
291
# Cache of the generic netlink families reported by the kernel, keyed by
# family name.  Stays None until _genl_load_families() fills it in.
genl_family_name_to_id = None
293
294
295def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
296    # we prepend length in _genl_msg_finalize()
297    if seq is None:
298        seq = random.randint(1, 1024)
299    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
300    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
301    return nlmsg + genlmsg
302
303
304def _genl_msg_finalize(msg):
305    return struct.pack("I", len(msg) + 4) + msg
306
307
def _genl_load_families():
    """Dump the generic netlink families from the kernel into the module cache.

    Sends a CTRL_CMD_GETFAMILY dump request to the nlctrl family and stores
    one dict per family (keys 'id', 'name', 'maxattr' and 'mcast', as
    present in the reply) in genl_family_name_to_id, keyed by family name.
    Reading stops at the first error message (which is printed) or at the
    done message.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        req_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP
        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL, req_flags, Netlink.CTRL_CMD_GETFAMILY, 1))

        sock.send(request, 0)

        genl_family_name_to_id = {}

        while True:
            data = sock.recv(128 * 1024)
            for nl_msg in NlMsgs(data):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                family = {}
                for attr in NlAttrs(GenlMsg(nl_msg).raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        family['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        family['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        family['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        groups = {}
                        family['mcast'] = groups
                        # Each entry is itself a nest holding name and ID
                        for entry in NlAttrs(attr.raw):
                            grp_name = None
                            grp_id = None
                            for field in NlAttrs(entry.raw):
                                if field.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    grp_name = field.as_strz()
                                elif field.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    grp_id = field.as_scalar('u32')
                            if grp_name and grp_id is not None:
                                groups[grp_name] = grp_id

                # Only families that reported both a name and an ID are kept
                if 'name' in family and 'id' in family:
                    genl_family_name_to_id[family['name']] = family
355
356
class GenlMsg:
    """Generic netlink view of a netlink message.

    Splits the 4-byte generic netlink header (command, version, and a
    16-bit field that is ignored) off the front of ``nl_msg.raw`` and keeps
    the remainder as ``raw``.
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the generic netlink command of the message."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is not set by __init__; it is attached later by
        # NetlinkProtocol.decode().  Treat it as empty until then so that
        # printing a freshly built message does not raise AttributeError.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg
372
373
class NetlinkProtocol:
    """Message building and decoding for a netlink family with no genl header."""

    def __init__(self, family_name, proto_num):
        self.family_name, self.proto_num = family_name, proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack the netlink header fields that follow the length field."""
        seq_no = random.randint(1, 1024) if seq is None else seq
        return struct.pack("HHII", nl_type, nl_flags, seq_no, 0)

    def message(self, flags, command, version, seq=None):
        """Build a request header; ``command`` is used as the message type.

        ``version`` is accepted for interface parity and not used here.
        """
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        """Return the protocol-specific view of ``nl_msg`` (here: unchanged)."""
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Attach the parsed attributes of ``nl_msg`` as ``raw_attrs``.

        When ``op`` is None it is looked up in ``ynl.rsp_by_value`` by the
        message's command.  Attribute parsing starts after the operation's
        fixed header.
        """
        decoded = self._decode(nl_msg)
        operation = ynl.rsp_by_value[decoded.cmd()] if op is None else op
        attrs_start = ynl._struct_size(operation.fixed_header)
        decoded.raw_attrs = NlAttrs(decoded.raw, attrs_start)
        return decoded

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the ``value`` of the named group taken from ``mcast_groups``."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        """Size in bytes of the netlink message header."""
        return 16
406
407
class GenlProtocol(NetlinkProtocol):
    """Generic netlink flavour of NetlinkProtocol.

    The family ID and multicast groups come from the module-level family
    cache, which is loaded from the kernel on first use.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        # Load the family cache once for the whole module
        if genl_family_name_to_id is None:
            _genl_load_families()

        # Raises KeyError when the family is not in the cache
        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']

    def message(self, flags, command, version, seq=None):
        """Build the netlink header followed by the generic netlink header."""
        nl_header = self._message(self.family_id, flags, seq)
        genl_header = struct.pack("BBH", command, version, 0)
        return nl_header + genl_header

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the ID of the named group as recorded in the family cache.

        ``mcast_groups`` is accepted for interface parity and not used here.
        """
        known_groups = self.genl_family['mcast']
        if mcast_name in known_groups:
            return known_groups[mcast_name]
        raise Exception(f'Multicast group "{mcast_name}" not present in the family')

    def msghdr_size(self):
        """Netlink header size plus the 4-byte generic netlink header."""
        return 4 + super().msghdr_size()
434
435
class SpaceAttrs:
    """Chain of (attribute space, values) scopes searched innermost first."""

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        own_scope = self.SpecValuesPair(attr_space, attrs)
        inherited = outer.scopes if outer else []
        self.scopes = [own_scope] + inherited

    def lookup(self, name):
        """Return the value of ``name`` from the first scope that defines it.

        Raises Exception when the defining scope holds no value for the
        name, or when no scope defines the name at all.
        """
        for spec, values in self.scopes:
            if name not in spec:
                continue
            if name not in values:
                spec_name = spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
            return values[name]
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")
453
454
455#
456# YNL implementation details.
457#
458
459
460class YnlFamily(SpecFamily):
461    def __init__(self, def_path, schema=None, process_unknown=False,
462                 recv_size=0):
463        super().__init__(def_path, schema)
464
465        self.include_raw = False
466        self.process_unknown = process_unknown
467
468        try:
469            if self.proto == "netlink-raw":
470                self.nlproto = NetlinkProtocol(self.yaml['name'],
471                                               self.yaml['protonum'])
472            else:
473                self.nlproto = GenlProtocol(self.yaml['name'])
474        except KeyError:
475            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
476
477        self._recv_dbg = False
478        # Note that netlink will use conservative (min) message size for
479        # the first dump recv() on the socket, our setting will only matter
480        # from the second recv() on.
481        self._recv_size = recv_size if recv_size else 131072
482        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
483        # for a message, so smaller receive sizes will lead to truncation.
484        # Note that the min size for other families may be larger than 4k!
485        if self._recv_size < 4000:
486            raise ConfigError()
487
488        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
489        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
490        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
491        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
492
493        self.async_msg_ids = set()
494        self.async_msg_queue = queue.Queue()
495
496        for msg in self.msgs.values():
497            if msg.is_async:
498                self.async_msg_ids.add(msg.rsp_value)
499
500        for op_name, op in self.ops.items():
501            bound_f = functools.partial(self._op, op_name)
502            setattr(self, op.ident_name, bound_f)
503
504
505    def ntf_subscribe(self, mcast_name):
506        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
507        self.sock.bind((0, 0))
508        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
509                             mcast_id)
510
511    def set_recv_dbg(self, enabled):
512        self._recv_dbg = enabled
513
514    def _recv_dbg_print(self, reply, nl_msgs):
515        if not self._recv_dbg:
516            return
517        print("Recv: read", len(reply), "bytes,",
518              len(nl_msgs.msgs), "messages", file=sys.stderr)
519        for nl_msg in nl_msgs:
520            print("  ", nl_msg, file=sys.stderr)
521
522    def _encode_enum(self, attr_spec, value):
523        enum = self.consts[attr_spec['enum']]
524        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
525            scalar = 0
526            if isinstance(value, str):
527                value = [value]
528            for single_value in value:
529                scalar += enum.entries[single_value].user_value(as_flags = True)
530            return scalar
531        else:
532            return enum.entries[value].user_value()
533
534    def _get_scalar(self, attr_spec, value):
535        try:
536            return int(value)
537        except (ValueError, TypeError) as e:
538            if 'enum' not in attr_spec:
539                raise e
540        return self._encode_enum(attr_spec, value)
541
542    def _add_attr(self, space, name, value, search_attrs):
543        try:
544            attr = self.attr_sets[space][name]
545        except KeyError:
546            raise Exception(f"Space '{space}' has no attribute '{name}'")
547        nl_type = attr.value
548
549        if attr.is_multi and isinstance(value, list):
550            attr_payload = b''
551            for subvalue in value:
552                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
553            return attr_payload
554
555        if attr["type"] == 'nest':
556            nl_type |= Netlink.NLA_F_NESTED
557            attr_payload = b''
558            sub_attrs = SpaceAttrs(self.attr_sets[space], value, search_attrs)
559            for subname, subvalue in value.items():
560                attr_payload += self._add_attr(attr['nested-attributes'],
561                                               subname, subvalue, sub_attrs)
562        elif attr["type"] == 'flag':
563            if not value:
564                # If value is absent or false then skip attribute creation.
565                return b''
566            attr_payload = b''
567        elif attr["type"] == 'string':
568            attr_payload = str(value).encode('ascii') + b'\x00'
569        elif attr["type"] == 'binary':
570            if isinstance(value, bytes):
571                attr_payload = value
572            elif isinstance(value, str):
573                attr_payload = bytes.fromhex(value)
574            elif isinstance(value, dict) and attr.struct_name:
575                attr_payload = self._encode_struct(attr.struct_name, value)
576            else:
577                raise Exception(f'Unknown type for binary attribute, value: {value}')
578        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
579            scalar = self._get_scalar(attr, value)
580            if attr.is_auto_scalar:
581                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
582            else:
583                attr_type = attr["type"]
584            format = NlAttr.get_format(attr_type, attr.byte_order)
585            attr_payload = format.pack(scalar)
586        elif attr['type'] in "bitfield32":
587            scalar_value = self._get_scalar(attr, value["value"])
588            scalar_selector = self._get_scalar(attr, value["selector"])
589            attr_payload = struct.pack("II", scalar_value, scalar_selector)
590        elif attr['type'] == 'sub-message':
591            msg_format = self._resolve_selector(attr, search_attrs)
592            attr_payload = b''
593            if msg_format.fixed_header:
594                attr_payload += self._encode_struct(msg_format.fixed_header, value)
595            if msg_format.attr_set:
596                if msg_format.attr_set in self.attr_sets:
597                    nl_type |= Netlink.NLA_F_NESTED
598                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
599                    for subname, subvalue in value.items():
600                        attr_payload += self._add_attr(msg_format.attr_set,
601                                                       subname, subvalue, sub_attrs)
602                else:
603                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
604        else:
605            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
606
607        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
608        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
609
610    def _decode_enum(self, raw, attr_spec):
611        enum = self.consts[attr_spec['enum']]
612        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
613            i = 0
614            value = set()
615            while raw:
616                if raw & 1:
617                    value.add(enum.entries_by_val[i].name)
618                raw >>= 1
619                i += 1
620        else:
621            value = enum.entries_by_val[raw].name
622        return value
623
624    def _decode_binary(self, attr, attr_spec):
625        if attr_spec.struct_name:
626            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
627        elif attr_spec.sub_type:
628            decoded = attr.as_c_array(attr_spec.sub_type)
629        else:
630            decoded = attr.as_bin()
631            if attr_spec.display_hint:
632                decoded = self._formatted_string(decoded, attr_spec.display_hint)
633        return decoded
634
635    def _decode_array_attr(self, attr, attr_spec):
636        decoded = []
637        offset = 0
638        while offset < len(attr.raw):
639            item = NlAttr(attr.raw, offset)
640            offset += item.full_len
641
642            if attr_spec["sub-type"] == 'nest':
643                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
644                decoded.append({ item.type: subattrs })
645            elif attr_spec["sub-type"] == 'binary':
646                subattrs = item.as_bin()
647                if attr_spec.display_hint:
648                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
649                decoded.append(subattrs)
650            elif attr_spec["sub-type"] in NlAttr.type_formats:
651                subattrs = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
652                if attr_spec.display_hint:
653                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
654                decoded.append(subattrs)
655            else:
656                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
657        return decoded
658
659    def _decode_nest_type_value(self, attr, attr_spec):
660        decoded = {}
661        value = attr
662        for name in attr_spec['type-value']:
663            value = NlAttr(value.raw, 0)
664            decoded[name] = value.type
665        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
666        decoded.update(subattrs)
667        return decoded
668
669    def _decode_unknown(self, attr):
670        if attr.is_nest:
671            return self._decode(NlAttrs(attr.raw), None)
672        else:
673            return attr.as_bin()
674
675    def _rsp_add(self, rsp, name, is_multi, decoded):
676        if is_multi == None:
677            if name in rsp and type(rsp[name]) is not list:
678                rsp[name] = [rsp[name]]
679                is_multi = True
680            else:
681                is_multi = False
682
683        if not is_multi:
684            rsp[name] = decoded
685        elif name in rsp:
686            rsp[name].append(decoded)
687        else:
688            rsp[name] = [decoded]
689
690    def _resolve_selector(self, attr_spec, search_attrs):
691        sub_msg = attr_spec.sub_message
692        if sub_msg not in self.sub_msgs:
693            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
694        sub_msg_spec = self.sub_msgs[sub_msg]
695
696        selector = attr_spec.selector
697        value = search_attrs.lookup(selector)
698        if value not in sub_msg_spec.formats:
699            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
700
701        spec = sub_msg_spec.formats[value]
702        return spec
703
704    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
705        msg_format = self._resolve_selector(attr_spec, search_attrs)
706        decoded = {}
707        offset = 0
708        if msg_format.fixed_header:
709            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header));
710            offset = self._struct_size(msg_format.fixed_header)
711        if msg_format.attr_set:
712            if msg_format.attr_set in self.attr_sets:
713                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
714                decoded.update(subdict)
715            else:
716                raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")
717        return decoded
718
719    def _decode(self, attrs, space, outer_attrs = None):
720        rsp = dict()
721        if space:
722            attr_space = self.attr_sets[space]
723            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
724
725        for attr in attrs:
726            try:
727                attr_spec = attr_space.attrs_by_val[attr.type]
728            except (KeyError, UnboundLocalError):
729                if not self.process_unknown:
730                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
731                attr_name = f"UnknownAttr({attr.type})"
732                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
733                continue
734
735            if attr_spec["type"] == 'nest':
736                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
737                decoded = subdict
738            elif attr_spec["type"] == 'string':
739                decoded = attr.as_strz()
740            elif attr_spec["type"] == 'binary':
741                decoded = self._decode_binary(attr, attr_spec)
742            elif attr_spec["type"] == 'flag':
743                decoded = True
744            elif attr_spec.is_auto_scalar:
745                decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
746            elif attr_spec["type"] in NlAttr.type_formats:
747                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
748                if 'enum' in attr_spec:
749                    decoded = self._decode_enum(decoded, attr_spec)
750                elif attr_spec.display_hint:
751                    decoded = self._formatted_string(decoded, attr_spec.display_hint)
752            elif attr_spec["type"] == 'indexed-array':
753                decoded = self._decode_array_attr(attr, attr_spec)
754            elif attr_spec["type"] == 'bitfield32':
755                value, selector = struct.unpack("II", attr.raw)
756                if 'enum' in attr_spec:
757                    value = self._decode_enum(value, attr_spec)
758                    selector = self._decode_enum(selector, attr_spec)
759                decoded = {"value": value, "selector": selector}
760            elif attr_spec["type"] == 'sub-message':
761                decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
762            elif attr_spec["type"] == 'nest-type-value':
763                decoded = self._decode_nest_type_value(attr, attr_spec)
764            else:
765                if not self.process_unknown:
766                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
767                decoded = self._decode_unknown(attr)
768
769            self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
770
771        return rsp
772
773    def _decode_extack_path(self, attrs, attr_set, offset, target):
774        for attr in attrs:
775            try:
776                attr_spec = attr_set.attrs_by_val[attr.type]
777            except KeyError:
778                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
779            if offset > target:
780                break
781            if offset == target:
782                return '.' + attr_spec.name
783
784            if offset + attr.full_len <= target:
785                offset += attr.full_len
786                continue
787            if attr_spec['type'] != 'nest':
788                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
789            offset += 4
790            subpath = self._decode_extack_path(NlAttrs(attr.raw),
791                                               self.attr_sets[attr_spec['nested-attributes']],
792                                               offset, target)
793            if subpath is None:
794                return None
795            return '.' + attr_spec.name + subpath
796
797        return None
798
799    def _decode_extack(self, request, op, extack):
800        if 'bad-attr-offs' not in extack:
801            return
802
803        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
804        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
805        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
806                                        extack['bad-attr-offs'])
807        if path:
808            del extack['bad-attr-offs']
809            extack['bad-attr'] = path
810
811    def _struct_size(self, name):
812        if name:
813            members = self.consts[name].members
814            size = 0
815            for m in members:
816                if m.type in ['pad', 'binary']:
817                    if m.struct:
818                        size += self._struct_size(m.struct)
819                    else:
820                        size += m.len
821                else:
822                    format = NlAttr.get_format(m.type, m.byte_order)
823                    size += format.size
824            return size
825        else:
826            return 0
827
828    def _decode_struct(self, data, name):
829        members = self.consts[name].members
830        attrs = dict()
831        offset = 0
832        for m in members:
833            value = None
834            if m.type == 'pad':
835                offset += m.len
836            elif m.type == 'binary':
837                if m.struct:
838                    len = self._struct_size(m.struct)
839                    value = self._decode_struct(data[offset : offset + len],
840                                                m.struct)
841                    offset += len
842                else:
843                    value = data[offset : offset + m.len]
844                    offset += m.len
845            else:
846                format = NlAttr.get_format(m.type, m.byte_order)
847                [ value ] = format.unpack_from(data, offset)
848                offset += format.size
849            if value is not None:
850                if m.enum:
851                    value = self._decode_enum(value, m)
852                elif m.display_hint:
853                    value = self._formatted_string(value, m.display_hint)
854                attrs[m.name] = value
855        return attrs
856
857    def _encode_struct(self, name, vals):
858        members = self.consts[name].members
859        attr_payload = b''
860        for m in members:
861            value = vals.pop(m.name) if m.name in vals else None
862            if m.type == 'pad':
863                attr_payload += bytearray(m.len)
864            elif m.type == 'binary':
865                if m.struct:
866                    if value is None:
867                        value = dict()
868                    attr_payload += self._encode_struct(m.struct, value)
869                else:
870                    if value is None:
871                        attr_payload += bytearray(m.len)
872                    else:
873                        attr_payload += bytes.fromhex(value)
874            else:
875                if value is None:
876                    value = 0
877                format = NlAttr.get_format(m.type, m.byte_order)
878                attr_payload += format.pack(value)
879        return attr_payload
880
881    def _formatted_string(self, raw, display_hint):
882        if display_hint == 'mac':
883            formatted = ':'.join('%02x' % b for b in raw)
884        elif display_hint == 'hex':
885            if isinstance(raw, int):
886                formatted = hex(raw)
887            else:
888                formatted = bytes.hex(raw, ' ')
889        elif display_hint in [ 'ipv4', 'ipv6' ]:
890            formatted = format(ipaddress.ip_address(raw))
891        elif display_hint == 'uuid':
892            formatted = str(uuid.UUID(bytes=raw))
893        else:
894            formatted = raw
895        return formatted
896
897    def handle_ntf(self, decoded):
898        msg = dict()
899        if self.include_raw:
900            msg['raw'] = decoded
901        op = self.rsp_by_value[decoded.cmd()]
902        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
903        if op.fixed_header:
904            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
905
906        msg['name'] = op['name']
907        msg['msg'] = attrs
908        self.async_msg_queue.put(msg)
909
910    def check_ntf(self, interval=0.1):
911        while True:
912            try:
913                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
914                nms = NlMsgs(reply)
915                self._recv_dbg_print(reply, nms)
916                for nl_msg in nms:
917                    if nl_msg.error:
918                        print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
919                        print(nl_msg)
920                        continue
921                    if nl_msg.done:
922                        print("Netlink done while checking for ntf!?")
923                        continue
924
925                    decoded = self.nlproto.decode(self, nl_msg, None)
926                    if decoded.cmd() not in self.async_msg_ids:
927                        print("Unexpected msg id while checking for ntf", decoded)
928                        continue
929
930                    self.handle_ntf(decoded)
931            except BlockingIOError:
932                pass
933
934            try:
935                yield self.async_msg_queue.get_nowait()
936            except queue.Empty:
937                try:
938                    time.sleep(interval)
939                except KeyboardInterrupt:
940                    return
941
942    def operation_do_attributes(self, name):
943      """
944      For a given operation name, find and return a supported
945      set of attributes (as a dict).
946      """
947      op = self.find_operation(name)
948      if not op:
949        return None
950
951      return op['do']['request']['attributes'].copy()
952
953    def _encode_message(self, op, vals, flags, req_seq):
954        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
955        for flag in flags or []:
956            nl_flags |= flag
957
958        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
959        if op.fixed_header:
960            msg += self._encode_struct(op.fixed_header, vals)
961        search_attrs = SpaceAttrs(op.attr_set, vals)
962        for name, value in vals.items():
963            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
964        msg = _genl_msg_finalize(msg)
965        return msg
966
967    def _ops(self, ops):
968        reqs_by_seq = {}
969        req_seq = random.randint(1024, 65535)
970        payload = b''
971        for (method, vals, flags) in ops:
972            op = self.ops[method]
973            msg = self._encode_message(op, vals, flags, req_seq)
974            reqs_by_seq[req_seq] = (op, msg, flags)
975            payload += msg
976            req_seq += 1
977
978        self.sock.send(payload, 0)
979
980        done = False
981        rsp = []
982        op_rsp = []
983        while not done:
984            reply = self.sock.recv(self._recv_size)
985            nms = NlMsgs(reply, attr_space=op.attr_set)
986            self._recv_dbg_print(reply, nms)
987            for nl_msg in nms:
988                if nl_msg.nl_seq in reqs_by_seq:
989                    (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
990                    if nl_msg.extack:
991                        self._decode_extack(req_msg, op, nl_msg.extack)
992                else:
993                    op = None
994                    req_flags = []
995
996                if nl_msg.error:
997                    raise NlError(nl_msg)
998                if nl_msg.done:
999                    if nl_msg.extack:
1000                        print("Netlink warning:")
1001                        print(nl_msg)
1002
1003                    if Netlink.NLM_F_DUMP in req_flags:
1004                        rsp.append(op_rsp)
1005                    elif not op_rsp:
1006                        rsp.append(None)
1007                    elif len(op_rsp) == 1:
1008                        rsp.append(op_rsp[0])
1009                    else:
1010                        rsp.append(op_rsp)
1011                    op_rsp = []
1012
1013                    del reqs_by_seq[nl_msg.nl_seq]
1014                    done = len(reqs_by_seq) == 0
1015                    break
1016
1017                decoded = self.nlproto.decode(self, nl_msg, op)
1018
1019                # Check if this is a reply to our request
1020                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
1021                    if decoded.cmd() in self.async_msg_ids:
1022                        self.handle_ntf(decoded)
1023                        continue
1024                    else:
1025                        print('Unexpected message: ' + repr(decoded))
1026                        continue
1027
1028                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
1029                if op.fixed_header:
1030                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
1031                op_rsp.append(rsp_msg)
1032
1033        return rsp
1034
1035    def _op(self, method, vals, flags=None, dump=False):
1036        req_flags = flags or []
1037        if dump:
1038            req_flags.append(Netlink.NLM_F_DUMP)
1039
1040        ops = [(method, vals, req_flags)]
1041        return self._ops(ops)[0]
1042
1043    def do(self, method, vals, flags=None):
1044        return self._op(method, vals, flags)
1045
1046    def dump(self, method, vals):
1047        return self._op(method, vals, dump=True)
1048
1049    def do_multi(self, ops):
1050        return self._ops(ops)
1051