xref: /linux/tools/net/ynl/lib/ynl.py (revision 163e9fc6957fc24d1d6c0a30a3febfd2ecade039)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4import functools
5import os
6import random
7import socket
8import struct
9from struct import Struct
10import sys
11import yaml
12import ipaddress
13import uuid
14
15from .nlspec import SpecFamily
16
17#
18# Generic Netlink code which should really be in some library, but I can't quickly find one.
19#
20
21
class Netlink:
    """Numeric constants from the netlink, generic netlink and nlctrl UAPI.

    Used throughout this module as ``Netlink.NAME``.
    """

    # Socket option level and options for AF_NETLINK sockets
    SOL_NETLINK = 270
    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Control message types
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # Standard nlmsg_flags bits
    NLM_F_REQUEST = 1 << 0
    NLM_F_ACK = 1 << 2

    # Modifiers for GET requests; both together request a dump
    NLM_F_ROOT = 1 << 8
    NLM_F_MATCH = 1 << 9
    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # Modifiers for NEW requests (they reuse the GET bit positions)
    NLM_F_REPLACE = 1 << 8
    NLM_F_EXCL = 1 << 9
    NLM_F_CREATE = 1 << 10
    NLM_F_APPEND = 1 << 11

    # Flags for ACK / error messages (same bit positions again)
    NLM_F_CAPPED = 1 << 8
    NLM_F_ACK_TLVS = 1 << 9

    # Flag bits carried in nla_type.  NLA_TYPE_MASK here is the mask of those
    # flag bits, so the bare attribute type is nla_type & ~NLA_TYPE_MASK.
    NLA_F_NESTED = 1 << 15
    NLA_F_NET_BYTEORDER = 1 << 14
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Generic netlink protocol number and the id of the nlctrl family
    NETLINK_GENERIC = 16
    GENL_ID_CTRL = 0x10

    # nlctrl command and attributes
    CTRL_CMD_GETFAMILY = 3
    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    # Attributes inside each CTRL_ATTR_MCAST_GROUPS entry
    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extended-ACK TLV types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6
78
79
class NlError(Exception):
    """Raised when the kernel answers a request with a netlink error.

    The parsed error message is kept in ``nl_msg``; its ``error`` field holds
    the negated errno, hence the sign flip before os.strerror().
    """

    def __init__(self, nl_msg):
        # Hand nl_msg to Exception so ``args`` is populated: copy/pickle
        # rebuild exceptions as cls(*args), which fails with empty args.
        super().__init__(nl_msg)
        self.nl_msg = nl_msg

    def __str__(self):
        return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}"
86
87
class ConfigError(Exception):
    """Raised when YnlFamily is given an unusable configuration, such as a
    receive buffer size that is too small."""
90
91
class NlAttr:
    """One netlink attribute (TLV) read from *raw* starting at byte *offset*.

    Attributes:
        type: attribute type with the NLA_F_NESTED / NLA_F_NET_BYTEORDER
            flag bits cleared.
        is_nest: non-zero when NLA_F_NESTED was set.
        payload_len: nla_len as sent, i.e. header plus payload.
        full_len: nla_len rounded up to a multiple of 4, which is the
            distance to the next attribute.
        raw: the payload bytes after the 4-byte header, up to nla_len.
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        if self._len < 4:
            # nla_len counts the 4-byte header itself, so smaller values are
            # malformed; with 0, code that walks attributes by full_len would
            # never advance.
            raise Exception(f"Invalid netlink attribute length {self._len} at offset {offset}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for *attr_type* (e.g. 'u32').

        No byte_order selects native order; "big-endian" selects big-endian
        and any other truthy value little-endian.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the payload as a single fixed-width integer of *attr_type*."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a variable-width (4 or 8 byte) integer payload.

        Only the first letter of *attr_type* ('u' or 's') is used; the width
        comes from the payload length.
        """
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode an ASCII string whose last byte (the NUL terminator) is dropped."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes unchanged."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as a native-order array of *type* integers."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
144
145
class NlAttrs:
    """All netlink attributes packed back to back in *msg*, from *offset* on."""

    def __init__(self, msg, offset=0):
        parsed = []
        pos = offset
        end = len(msg)
        while pos < end:
            nla = NlAttr(msg, pos)
            parsed.append(nla)
            # Attributes are padded to 4 bytes; full_len includes that.
            pos += nla.full_len
        self.attrs = parsed

    def __iter__(self):
        return iter(self.attrs)

    def __repr__(self):
        # One attribute per line
        return '\n'.join(repr(nla) for nla in self.attrs)
165
166
class NlMsg:
    """One netlink message (16-byte nlmsghdr + payload) read from *msg* at *offset*.

    For NLMSG_ERROR and NLMSG_DONE, the leading 32-bit status goes into
    ``error`` and ``done`` is set. When the message carries NLM_F_ACK_TLVS,
    the extended-ACK TLVs are decoded into the ``extack`` dict. If
    *attr_space* is given, a missing top-level attribute is also described by
    name.
    """

    def __init__(self, msg, offset, attr_space=None):
        header = msg[offset : offset + 16]
        self.hdr = header
        (self.nl_len, self.nl_type, self.nl_flags,
         self.nl_seq, self.nl_portid) = struct.unpack("IHHII", header)
        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0
        self.extack = None

        # Where the extack TLVs begin inside self.raw. NLMSG_ERROR carries
        # the status int plus the echoed 16-byte request header (our sockets
        # set NETLINK_CAP_ACK); NLMSG_DONE carries only the status int.
        tlv_start = {
            Netlink.NLMSG_ERROR: 20,
            Netlink.NLMSG_DONE: 4,
        }.get(self.nl_type)
        if tlv_start is None:
            return

        self.error = struct.unpack("i", self.raw[0:4])[0]
        self.done = 1

        if not (self.nl_flags & Netlink.NLM_F_ACK_TLVS):
            return
        self._parse_extack(self.raw[tlv_start:], attr_space)

    def _parse_extack(self, tlv_bytes, attr_space):
        # Decode extended-ACK TLVs into self.extack.
        u32_keys = {
            Netlink.NLMSGERR_ATTR_MISS_TYPE: 'miss-type',
            Netlink.NLMSGERR_ATTR_MISS_NEST: 'miss-nest',
            Netlink.NLMSGERR_ATTR_OFFS: 'bad-attr-offs',
        }
        extack = {}
        self.extack = extack
        for tlv in NlAttrs(tlv_bytes):
            if tlv.type == Netlink.NLMSGERR_ATTR_MSG:
                extack['msg'] = tlv.as_strz()
            elif tlv.type in u32_keys:
                extack[u32_keys[tlv.type]] = tlv.as_scalar('u32')
            else:
                extack.setdefault('unknown', []).append(tlv)

        if not attr_space:
            return
        # Nests cannot be resolved yet, so only describe top-level misses.
        if 'miss-type' not in extack or 'miss-nest' in extack:
            return
        miss_type = extack['miss-type']
        if miss_type not in attr_space.attrs_by_val:
            return
        spec = attr_space.attrs_by_val[miss_type]
        desc = spec['name']
        if 'doc' in spec:
            desc += f" ({spec['doc']})"
        extack['miss-type'] = desc

    def cmd(self):
        """Return the message type, which doubles as the command for classic netlink."""
        return self.nl_type

    def __repr__(self):
        lines = [f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"]
        if self.error:
            lines.append('\terror: ' + str(self.error))
        if self.extack:
            lines.append('\textack: ' + repr(self.extack))
        return '\n'.join(lines)
228
229
class NlMsgs:
    """Split one receive buffer into its NlMsg messages, stored in ``msgs``."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            if msg.nl_len < 16:
                # nlmsg_len counts the 16-byte header; anything smaller is
                # malformed, and 0 would make this loop spin forever.
                raise Exception(f"Invalid netlink message length {msg.nl_len} at offset {offset}")
            self.msgs.append(msg)
            # Each message starts on a 4-byte boundary (NLMSG_ALIGN).
            offset += (msg.nl_len + 3) & ~3

    def __iter__(self):
        yield from self.msgs
242
243
# Cache of generic netlink families, filled by _genl_load_families() on first
# use. It maps family name -> dict with 'id' and 'name', plus 'maxattr' and
# 'mcast' (group name -> group id) when the kernel reports them. It stays None
# until loaded.
genl_family_name_to_id = None
245
246
247def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
248    # we prepend length in _genl_msg_finalize()
249    if seq is None:
250        seq = random.randint(1, 1024)
251    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
252    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
253    return nlmsg + genlmsg
254
255
256def _genl_msg_finalize(msg):
257    return struct.pack("I", len(msg) + 4) + msg
258
259
def _genl_parse_mcast_groups(attr):
    # Map group name -> group id from a CTRL_ATTR_MCAST_GROUPS nest.
    groups = dict()
    for entry in NlAttrs(attr.raw):
        grp_name = None
        grp_id = None
        for entry_attr in NlAttrs(entry.raw):
            if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                grp_name = entry_attr.as_strz()
            elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                grp_id = entry_attr.as_scalar('u32')
        if grp_name and grp_id is not None:
            groups[grp_name] = grp_id
    return groups


def _genl_parse_family(gm):
    # Collect the nlctrl attributes of one GETFAMILY reply into a dict.
    fam = dict()
    for attr in NlAttrs(gm.raw):
        if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
            fam['id'] = attr.as_scalar('u16')
        elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
            fam['name'] = attr.as_strz()
        elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
            fam['maxattr'] = attr.as_scalar('u32')
        elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
            fam['mcast'] = _genl_parse_mcast_groups(attr)
    return fam


def _genl_load_families():
    """Dump all generic netlink families from nlctrl into genl_family_name_to_id.

    Only families reporting both a name and an id are recorded. On a netlink
    error, the error code is printed and loading stops, keeping whatever was
    collected so far.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        dump_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP
        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL, dump_flags, Netlink.CTRL_CMD_GETFAMILY, 1))
        sock.send(request, 0)

        genl_family_name_to_id = dict()

        # Read until the dump's DONE (or an error) arrives.
        while True:
            reply = sock.recv(128 * 1024)
            for nl_msg in NlMsgs(reply):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return
                fam = _genl_parse_family(GenlMsg(nl_msg))
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam
307
308
class GenlMsg:
    """A generic netlink message: the 4-byte genlmsghdr split off a NlMsg payload.

    ``raw`` is the payload after the genl header. ``raw_attrs`` is attached
    later by NetlinkProtocol.decode() and may be absent.
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the genetlink command byte."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # Messages that never went through decode() (e.g. while loading the
        # family list) have no raw_attrs; don't let repr() blow up on them.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg
324
325
class NetlinkProtocol:
    """Framing helpers for a classic netlink protocol ('netlink-raw' specs).

    GenlProtocol subclasses this and adds the genetlink header.
    """

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        # Protocol number used for socket(AF_NETLINK, SOCK_RAW, proto_num)
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack an nlmsghdr without its leading nlmsg_len (port id 0).

        A random sequence number in [1, 1024] is used when *seq* is None.
        """
        sequence = random.randint(1, 1024) if seq is None else seq
        return struct.pack("HHII", nl_type, nl_flags, sequence, 0)

    def message(self, flags, command, version, seq=None):
        """Start a request; for classic netlink *command* is the nlmsg_type
        and *version* is unused."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        # Classic netlink has no extra header to strip.
        return nl_msg

    def decode(self, ynl, nl_msg):
        """Wrap *nl_msg* in this protocol's message type and attach ``raw_attrs``.

        When *ynl* is given, the fixed header of the response op matching the
        message's command is skipped before parsing attributes.
        """
        decoded = self._decode(nl_msg)
        header_len = 0
        if ynl:
            rsp_op = ynl.rsp_by_value[decoded.cmd()]
            header_len = ynl._struct_size(rsp_op.fixed_header)
        decoded.raw_attrs = NlAttrs(decoded.raw, header_len)
        return decoded

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the id of multicast group *mcast_name* as declared in the spec."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        """Size of the protocol header in front of the payload (the nlmsghdr)."""
        return 16
359
360
class GenlProtocol(NetlinkProtocol):
    """Framing helpers for a generic netlink family.

    The family id and multicast group ids come from the kernel's nlctrl
    family list. That list is loaded into genl_family_name_to_id on first use.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        # A KeyError here (family unknown to the kernel) is what YnlFamily
        # reports as "not supported by the kernel".
        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = self.genl_family['id']

    def message(self, flags, command, version, seq=None):
        """Start a request: nlmsghdr (type = family id) followed by genlmsghdr."""
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the kernel-assigned id of multicast group *mcast_name*.

        *mcast_groups* (the spec's view) is not consulted; the ids reported
        by the kernel are authoritative. Raises Exception when the family has
        no such group.
        """
        # Families without any multicast groups have no 'mcast' key at all.
        family_groups = self.genl_family.get('mcast', {})
        if mcast_name not in family_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return family_groups[mcast_name]

    def msghdr_size(self):
        # nlmsghdr plus the 4-byte genlmsghdr
        return super().msghdr_size() + 4
387
388
class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes for resolving selectors.

    The innermost scope is first in ``scopes``; lookup() walks outwards.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer=None):
        parent_scopes = outer.scopes if outer else []
        self.scopes = [self.SpecValuesPair(attr_space, attrs), *parent_scopes]

    def lookup(self, name):
        """Return the value of *name* from the innermost scope whose spec defines it.

        Raises Exception if that scope holds no value for *name*, or if no
        scope's spec defines it.
        """
        for spec, values in self.scopes:
            if name not in spec:
                continue
            if name in values:
                return values[name]
            raise Exception(
                f"No value for '{name}' in attribute space '{spec.yaml['name']}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")
406
407
408#
409# YNL implementation details.
410#
411
412
413class YnlFamily(SpecFamily):
414    def __init__(self, def_path, schema=None, process_unknown=False,
415                 recv_size=0):
416        super().__init__(def_path, schema)
417
418        self.include_raw = False
419        self.process_unknown = process_unknown
420
421        try:
422            if self.proto == "netlink-raw":
423                self.nlproto = NetlinkProtocol(self.yaml['name'],
424                                               self.yaml['protonum'])
425            else:
426                self.nlproto = GenlProtocol(self.yaml['name'])
427        except KeyError:
428            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
429
430        self._recv_dbg = False
431        # Note that netlink will use conservative (min) message size for
432        # the first dump recv() on the socket, our setting will only matter
433        # from the second recv() on.
434        self._recv_size = recv_size if recv_size else 131072
435        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
436        # for a message, so smaller receive sizes will lead to truncation.
437        # Note that the min size for other families may be larger than 4k!
438        if self._recv_size < 4000:
439            raise ConfigError()
440
441        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
442        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
443        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
444        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
445
446        self.async_msg_ids = set()
447        self.async_msg_queue = []
448
449        for msg in self.msgs.values():
450            if msg.is_async:
451                self.async_msg_ids.add(msg.rsp_value)
452
453        for op_name, op in self.ops.items():
454            bound_f = functools.partial(self._op, op_name)
455            setattr(self, op.ident_name, bound_f)
456
457
458    def ntf_subscribe(self, mcast_name):
459        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
460        self.sock.bind((0, 0))
461        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
462                             mcast_id)
463
464    def set_recv_dbg(self, enabled):
465        self._recv_dbg = enabled
466
467    def _recv_dbg_print(self, reply, nl_msgs):
468        if not self._recv_dbg:
469            return
470        print("Recv: read", len(reply), "bytes,",
471              len(nl_msgs.msgs), "messages", file=sys.stderr)
472        for nl_msg in nl_msgs:
473            print("  ", nl_msg, file=sys.stderr)
474
475    def _encode_enum(self, attr_spec, value):
476        enum = self.consts[attr_spec['enum']]
477        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
478            scalar = 0
479            if isinstance(value, str):
480                value = [value]
481            for single_value in value:
482                scalar += enum.entries[single_value].user_value(as_flags = True)
483            return scalar
484        else:
485            return enum.entries[value].user_value()
486
487    def _get_scalar(self, attr_spec, value):
488        try:
489            return int(value)
490        except (ValueError, TypeError) as e:
491            if 'enum' not in attr_spec:
492                raise e
493        return self._encode_enum(attr_spec, value)
494
495    def _add_attr(self, space, name, value, search_attrs):
496        try:
497            attr = self.attr_sets[space][name]
498        except KeyError:
499            raise Exception(f"Space '{space}' has no attribute '{name}'")
500        nl_type = attr.value
501
502        if attr.is_multi and isinstance(value, list):
503            attr_payload = b''
504            for subvalue in value:
505                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
506            return attr_payload
507
508        if attr["type"] == 'nest':
509            nl_type |= Netlink.NLA_F_NESTED
510            attr_payload = b''
511            sub_attrs = SpaceAttrs(self.attr_sets[space], value, search_attrs)
512            for subname, subvalue in value.items():
513                attr_payload += self._add_attr(attr['nested-attributes'],
514                                               subname, subvalue, sub_attrs)
515        elif attr["type"] == 'flag':
516            if not value:
517                # If value is absent or false then skip attribute creation.
518                return b''
519            attr_payload = b''
520        elif attr["type"] == 'string':
521            attr_payload = str(value).encode('ascii') + b'\x00'
522        elif attr["type"] == 'binary':
523            if isinstance(value, bytes):
524                attr_payload = value
525            elif isinstance(value, str):
526                attr_payload = bytes.fromhex(value)
527            elif isinstance(value, dict) and attr.struct_name:
528                attr_payload = self._encode_struct(attr.struct_name, value)
529            else:
530                raise Exception(f'Unknown type for binary attribute, value: {value}')
531        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
532            scalar = self._get_scalar(attr, value)
533            if attr.is_auto_scalar:
534                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
535            else:
536                attr_type = attr["type"]
537            format = NlAttr.get_format(attr_type, attr.byte_order)
538            attr_payload = format.pack(scalar)
539        elif attr['type'] in "bitfield32":
540            scalar_value = self._get_scalar(attr, value["value"])
541            scalar_selector = self._get_scalar(attr, value["selector"])
542            attr_payload = struct.pack("II", scalar_value, scalar_selector)
543        elif attr['type'] == 'sub-message':
544            msg_format = self._resolve_selector(attr, search_attrs)
545            attr_payload = b''
546            if msg_format.fixed_header:
547                attr_payload += self._encode_struct(msg_format.fixed_header, value)
548            if msg_format.attr_set:
549                if msg_format.attr_set in self.attr_sets:
550                    nl_type |= Netlink.NLA_F_NESTED
551                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
552                    for subname, subvalue in value.items():
553                        attr_payload += self._add_attr(msg_format.attr_set,
554                                                       subname, subvalue, sub_attrs)
555                else:
556                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
557        else:
558            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
559
560        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
561        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
562
563    def _decode_enum(self, raw, attr_spec):
564        enum = self.consts[attr_spec['enum']]
565        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
566            i = 0
567            value = set()
568            while raw:
569                if raw & 1:
570                    value.add(enum.entries_by_val[i].name)
571                raw >>= 1
572                i += 1
573        else:
574            value = enum.entries_by_val[raw].name
575        return value
576
577    def _decode_binary(self, attr, attr_spec):
578        if attr_spec.struct_name:
579            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
580        elif attr_spec.sub_type:
581            decoded = attr.as_c_array(attr_spec.sub_type)
582        else:
583            decoded = attr.as_bin()
584            if attr_spec.display_hint:
585                decoded = self._formatted_string(decoded, attr_spec.display_hint)
586        return decoded
587
588    def _decode_array_nest(self, attr, attr_spec):
589        decoded = []
590        offset = 0
591        while offset < len(attr.raw):
592            item = NlAttr(attr.raw, offset)
593            offset += item.full_len
594
595            subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
596            decoded.append({ item.type: subattrs })
597        return decoded
598
599    def _decode_nest_type_value(self, attr, attr_spec):
600        decoded = {}
601        value = attr
602        for name in attr_spec['type-value']:
603            value = NlAttr(value.raw, 0)
604            decoded[name] = value.type
605        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
606        decoded.update(subattrs)
607        return decoded
608
609    def _decode_unknown(self, attr):
610        if attr.is_nest:
611            return self._decode(NlAttrs(attr.raw), None)
612        else:
613            return attr.as_bin()
614
615    def _rsp_add(self, rsp, name, is_multi, decoded):
616        if is_multi == None:
617            if name in rsp and type(rsp[name]) is not list:
618                rsp[name] = [rsp[name]]
619                is_multi = True
620            else:
621                is_multi = False
622
623        if not is_multi:
624            rsp[name] = decoded
625        elif name in rsp:
626            rsp[name].append(decoded)
627        else:
628            rsp[name] = [decoded]
629
630    def _resolve_selector(self, attr_spec, search_attrs):
631        sub_msg = attr_spec.sub_message
632        if sub_msg not in self.sub_msgs:
633            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
634        sub_msg_spec = self.sub_msgs[sub_msg]
635
636        selector = attr_spec.selector
637        value = search_attrs.lookup(selector)
638        if value not in sub_msg_spec.formats:
639            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
640
641        spec = sub_msg_spec.formats[value]
642        return spec
643
644    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
645        msg_format = self._resolve_selector(attr_spec, search_attrs)
646        decoded = {}
647        offset = 0
648        if msg_format.fixed_header:
649            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header));
650            offset = self._struct_size(msg_format.fixed_header)
651        if msg_format.attr_set:
652            if msg_format.attr_set in self.attr_sets:
653                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
654                decoded.update(subdict)
655            else:
656                raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")
657        return decoded
658
659    def _decode(self, attrs, space, outer_attrs = None):
660        rsp = dict()
661        if space:
662            attr_space = self.attr_sets[space]
663            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
664
665        for attr in attrs:
666            try:
667                attr_spec = attr_space.attrs_by_val[attr.type]
668            except (KeyError, UnboundLocalError):
669                if not self.process_unknown:
670                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
671                attr_name = f"UnknownAttr({attr.type})"
672                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
673                continue
674
675            if attr_spec["type"] == 'nest':
676                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
677                decoded = subdict
678            elif attr_spec["type"] == 'string':
679                decoded = attr.as_strz()
680            elif attr_spec["type"] == 'binary':
681                decoded = self._decode_binary(attr, attr_spec)
682            elif attr_spec["type"] == 'flag':
683                decoded = True
684            elif attr_spec.is_auto_scalar:
685                decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
686            elif attr_spec["type"] in NlAttr.type_formats:
687                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
688                if 'enum' in attr_spec:
689                    decoded = self._decode_enum(decoded, attr_spec)
690            elif attr_spec["type"] == 'array-nest':
691                decoded = self._decode_array_nest(attr, attr_spec)
692            elif attr_spec["type"] == 'bitfield32':
693                value, selector = struct.unpack("II", attr.raw)
694                if 'enum' in attr_spec:
695                    value = self._decode_enum(value, attr_spec)
696                    selector = self._decode_enum(selector, attr_spec)
697                decoded = {"value": value, "selector": selector}
698            elif attr_spec["type"] == 'sub-message':
699                decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
700            elif attr_spec["type"] == 'nest-type-value':
701                decoded = self._decode_nest_type_value(attr, attr_spec)
702            else:
703                if not self.process_unknown:
704                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
705                decoded = self._decode_unknown(attr)
706
707            self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
708
709        return rsp
710
711    def _decode_extack_path(self, attrs, attr_set, offset, target):
712        for attr in attrs:
713            try:
714                attr_spec = attr_set.attrs_by_val[attr.type]
715            except KeyError:
716                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
717            if offset > target:
718                break
719            if offset == target:
720                return '.' + attr_spec.name
721
722            if offset + attr.full_len <= target:
723                offset += attr.full_len
724                continue
725            if attr_spec['type'] != 'nest':
726                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
727            offset += 4
728            subpath = self._decode_extack_path(NlAttrs(attr.raw),
729                                               self.attr_sets[attr_spec['nested-attributes']],
730                                               offset, target)
731            if subpath is None:
732                return None
733            return '.' + attr_spec.name + subpath
734
735        return None
736
737    def _decode_extack(self, request, op, extack):
738        if 'bad-attr-offs' not in extack:
739            return
740
741        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set))
742        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
743        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
744                                        extack['bad-attr-offs'])
745        if path:
746            del extack['bad-attr-offs']
747            extack['bad-attr'] = path
748
749    def _struct_size(self, name):
750        if name:
751            members = self.consts[name].members
752            size = 0
753            for m in members:
754                if m.type in ['pad', 'binary']:
755                    if m.struct:
756                        size += self._struct_size(m.struct)
757                    else:
758                        size += m.len
759                else:
760                    format = NlAttr.get_format(m.type, m.byte_order)
761                    size += format.size
762            return size
763        else:
764            return 0
765
766    def _decode_struct(self, data, name):
767        members = self.consts[name].members
768        attrs = dict()
769        offset = 0
770        for m in members:
771            value = None
772            if m.type == 'pad':
773                offset += m.len
774            elif m.type == 'binary':
775                if m.struct:
776                    len = self._struct_size(m.struct)
777                    value = self._decode_struct(data[offset : offset + len],
778                                                m.struct)
779                    offset += len
780                else:
781                    value = data[offset : offset + m.len]
782                    offset += m.len
783            else:
784                format = NlAttr.get_format(m.type, m.byte_order)
785                [ value ] = format.unpack_from(data, offset)
786                offset += format.size
787            if value is not None:
788                if m.enum:
789                    value = self._decode_enum(value, m)
790                elif m.display_hint:
791                    value = self._formatted_string(value, m.display_hint)
792                attrs[m.name] = value
793        return attrs
794
795    def _encode_struct(self, name, vals):
796        members = self.consts[name].members
797        attr_payload = b''
798        for m in members:
799            value = vals.pop(m.name) if m.name in vals else None
800            if m.type == 'pad':
801                attr_payload += bytearray(m.len)
802            elif m.type == 'binary':
803                if m.struct:
804                    if value is None:
805                        value = dict()
806                    attr_payload += self._encode_struct(m.struct, value)
807                else:
808                    if value is None:
809                        attr_payload += bytearray(m.len)
810                    else:
811                        attr_payload += bytes.fromhex(value)
812            else:
813                if value is None:
814                    value = 0
815                format = NlAttr.get_format(m.type, m.byte_order)
816                attr_payload += format.pack(value)
817        return attr_payload
818
819    def _formatted_string(self, raw, display_hint):
820        if display_hint == 'mac':
821            formatted = ':'.join('%02x' % b for b in raw)
822        elif display_hint == 'hex':
823            formatted = bytes.hex(raw, ' ')
824        elif display_hint in [ 'ipv4', 'ipv6' ]:
825            formatted = format(ipaddress.ip_address(raw))
826        elif display_hint == 'uuid':
827            formatted = str(uuid.UUID(bytes=raw))
828        else:
829            formatted = raw
830        return formatted
831
832    def handle_ntf(self, decoded):
833        msg = dict()
834        if self.include_raw:
835            msg['raw'] = decoded
836        op = self.rsp_by_value[decoded.cmd()]
837        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
838        if op.fixed_header:
839            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
840
841        msg['name'] = op['name']
842        msg['msg'] = attrs
843        self.async_msg_queue.append(msg)
844
845    def check_ntf(self):
846        while True:
847            try:
848                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
849            except BlockingIOError:
850                return
851
852            nms = NlMsgs(reply)
853            self._recv_dbg_print(reply, nms)
854            for nl_msg in nms:
855                if nl_msg.error:
856                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
857                    print(nl_msg)
858                    continue
859                if nl_msg.done:
860                    print("Netlink done while checking for ntf!?")
861                    continue
862
863                decoded = self.nlproto.decode(self, nl_msg)
864                if decoded.cmd() not in self.async_msg_ids:
865                    print("Unexpected msg id done while checking for ntf", decoded)
866                    continue
867
868                self.handle_ntf(decoded)
869
870    def operation_do_attributes(self, name):
871      """
872      For a given operation name, find and return a supported
873      set of attributes (as a dict).
874      """
875      op = self.find_operation(name)
876      if not op:
877        return None
878
879      return op['do']['request']['attributes'].copy()
880
881    def _op(self, method, vals, flags=None, dump=False):
882        op = self.ops[method]
883
884        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
885        for flag in flags or []:
886            nl_flags |= flag
887        if dump:
888            nl_flags |= Netlink.NLM_F_DUMP
889
890        req_seq = random.randint(1024, 65535)
891        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
892        if op.fixed_header:
893            msg += self._encode_struct(op.fixed_header, vals)
894        search_attrs = SpaceAttrs(op.attr_set, vals)
895        for name, value in vals.items():
896            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
897        msg = _genl_msg_finalize(msg)
898
899        self.sock.send(msg, 0)
900
901        done = False
902        rsp = []
903        while not done:
904            reply = self.sock.recv(self._recv_size)
905            nms = NlMsgs(reply, attr_space=op.attr_set)
906            self._recv_dbg_print(reply, nms)
907            for nl_msg in nms:
908                if nl_msg.extack:
909                    self._decode_extack(msg, op, nl_msg.extack)
910
911                if nl_msg.error:
912                    raise NlError(nl_msg)
913                if nl_msg.done:
914                    if nl_msg.extack:
915                        print("Netlink warning:")
916                        print(nl_msg)
917                    done = True
918                    break
919
920                decoded = self.nlproto.decode(self, nl_msg)
921
922                # Check if this is a reply to our request
923                if nl_msg.nl_seq != req_seq or decoded.cmd() != op.rsp_value:
924                    if decoded.cmd() in self.async_msg_ids:
925                        self.handle_ntf(decoded)
926                        continue
927                    else:
928                        print('Unexpected message: ' + repr(decoded))
929                        continue
930
931                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
932                if op.fixed_header:
933                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
934                rsp.append(rsp_msg)
935
936        if not rsp:
937            return None
938        if not dump and len(rsp) == 1:
939            return rsp[0]
940        return rsp
941
942    def do(self, method, vals, flags=None):
943        return self._op(method, vals, flags)
944
945    def dump(self, method, vals):
946        return self._op(method, vals, [], dump=True)
947