xref: /linux/tools/net/ynl/lib/ynl.py (revision e6a901a00822659181c93c86d8bbc2a17779fddc)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4import functools
5import os
6import random
7import socket
8import struct
9from struct import Struct
10import sys
11import yaml
12import ipaddress
13import uuid
14
15from .nlspec import SpecFamily
16
17#
18# Generic Netlink code which should really be in some library, but I can't quickly find one.
19#
20
21
class Netlink:
    """Numeric constants of the Netlink and Generic Netlink wire protocols.

    The class only serves as a namespace for the values used by the rest of
    this module; it is never instantiated.
    """

    # Socket level and socket options passed to setsockopt()
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Control message types (nlmsg_type)
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # Message header flags (nlmsg_flags).  Bits 0x100 and up are reused under
    # several names; which name applies depends on the kind of message.
    NLM_F_REQUEST = 0x001
    NLM_F_ACK = 0x004
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # Flag bits stored in the upper part of an attribute's type field
    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    # Union of the flag bits above; callers clear these bits from the raw
    # type field to obtain the attribute number.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl: command and attribute numbers of the genetlink controller
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6
78
79
class NlError(Exception):
    """Exception wrapping a Netlink message that carries an error code.

    The message object is kept in ``nl_msg``; its ``error`` attribute is
    expected to hold a negated errno value.
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg

    def __str__(self):
        # The stored code is negative, flip the sign before asking the OS
        # for the matching description.
        errno_text = os.strerror(-self.nl_msg.error)
        return f"Netlink error: {errno_text}\n{self.nl_msg}"
86
87
class ConfigError(Exception):
    """Raised when a YNL object is set up with unusable parameters."""
90
91
class NlAttr:
    """A single Netlink attribute sliced out of a raw buffer.

    The constructor reads the 4-byte attribute header (length, type) found
    at ``offset`` and keeps the bytes that follow it in ``raw``.  The
    ``as_*`` helpers interpret that payload in various ways.
    """

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    # Pre-compiled struct formats per scalar type, one per byte order
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute that starts at ``offset`` inside ``raw``."""
        self._len, self._type = struct.unpack_from("HH", raw, offset)
        # Strip the flag bits to get the plain attribute number
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        # Note: this is the length field as found in the header, i.e. it
        # includes the 4 header bytes (see the slice below).
        self.payload_len = self._len
        # Distance to the next attribute: length rounded up to 4 bytes
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for ``attr_type`` in the requested byte order.

        ``byte_order`` of "big-endian" selects the big-endian format, any
        other truthy value the little-endian one, and a falsy value the
        native format.  Raises KeyError for an unknown ``attr_type``.
        """
        formats = cls.type_formats[attr_type]
        if byte_order:
            return formats.big if byte_order == "big-endian" \
                else formats.little
        return formats.native

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as one scalar of the given type."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable-width scalar whose width follows the payload.

        Only the first character of ``attr_type`` is used (signedness
        prefix); the width is 32 or 64 bits depending on the payload size.
        Raises Exception when the payload is neither 4 nor 8 bytes long.
        """
        size = len(self.raw)
        if size != 4 and size != 8:
            raise Exception(f"Auto-scalar payload len must be 4 or 8 bytes, got {size}")
        real_type = attr_type[0] + str(size * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII text, dropping its last byte.

        The dropped byte is expected to be the NUL terminator.
        """
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes untouched."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of native scalars."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
144
145
class NlAttrs:
    """List of attributes parsed back to back from a buffer.

    Parsing starts at ``offset`` inside ``msg`` and continues until the end
    of the buffer.  The parsed NlAttr objects are available through
    iteration or the ``attrs`` list.

    Raises Exception when an attribute reports a length of zero, since the
    parser could then never move past it.
    """

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            if not attr.full_len:
                # A zero length would leave offset unchanged and make this
                # loop append the same attribute forever.
                raise Exception(
                    f"Zero-length attribute at offset {offset}, cannot continue parsing")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        # One attribute per line
        return '\n'.join(repr(a) for a in self.attrs)
165
166
class NlMsg:
    """A single Netlink message cut out of a receive buffer.

    The 16-byte message header found at ``offset`` in ``msg`` is unpacked
    into the ``nl_*`` fields and the rest of the message is kept in ``raw``.
    Error and done messages set ``done`` (and ``error`` for the former); if
    they carry extended ACK attributes those are decoded into the ``extack``
    dict, otherwise ``extack`` stays None.

    When ``attr_space`` is given, a top-level 'miss-type' extack value is
    replaced with the name (and doc, if any) of the matching attribute.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]
        (self.nl_len, self.nl_type, self.nl_flags,
         self.nl_seq, self.nl_portid) = struct.unpack("IHHII", self.hdr)
        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        # Where the extack attributes start within self.raw, if anywhere
        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # presumably the error code plus the echoed request header
            # (4 + 16 bytes) -- TODO confirm
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.done = 1
            extack_off = 4

        self.extack = None
        if not (self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off):
            return

        self._parse_extack(extack_off)
        if attr_space:
            self._describe_missing_attr(attr_space)

    def _parse_extack(self, extack_off):
        """Fill self.extack from the attributes found at extack_off."""
        # Extack attributes which are plain u32 values, and the key each
        # one is stored under.
        u32_keys = {
            Netlink.NLMSGERR_ATTR_MISS_TYPE: 'miss-type',
            Netlink.NLMSGERR_ATTR_MISS_NEST: 'miss-nest',
            Netlink.NLMSGERR_ATTR_OFFS: 'bad-attr-offs',
        }

        self.extack = dict()
        for extack in NlAttrs(self.raw[extack_off:]):
            if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                self.extack['msg'] = extack.as_strz()
            elif extack.type in u32_keys:
                self.extack[u32_keys[extack.type]] = extack.as_scalar('u32')
            else:
                self.extack.setdefault('unknown', []).append(extack)

    def _describe_missing_attr(self, attr_space):
        """Swap a numeric 'miss-type' for the attribute's name and doc."""
        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' not in self.extack or 'miss-nest' in self.extack:
            return
        miss_type = self.extack['miss-type']
        if miss_type not in attr_space.attrs_by_val:
            return
        spec = attr_space.attrs_by_val[miss_type]
        desc = spec['name']
        if 'doc' in spec:
            desc += f" ({spec['doc']})"
        self.extack['miss-type'] = desc

    def cmd(self):
        """Return the message type, which doubles as the command id."""
        return self.nl_type

    def __repr__(self):
        parts = [f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"]
        if self.error:
            parts.append('\n\terror: ' + str(self.error))
        if self.extack:
            parts.append('\n\textack: ' + repr(self.extack))
        return ''.join(parts)
227
228
class NlMsgs:
    """All Netlink messages contained in one receive buffer.

    ``data`` is split into NlMsg objects (``attr_space`` is handed to each
    of them); they are available through iteration or the ``msgs`` list.

    Raises Exception when a message header announces a length smaller than
    the header itself, as the buffer cannot be walked any further then.
    """

    # Size of the fixed Netlink message header
    _HDR_LEN = 16

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            if msg.nl_len < self._HDR_LEN:
                # A zero length in particular would never advance offset
                # and loop forever.
                raise Exception(
                    f"Netlink message at offset {offset} has invalid length {msg.nl_len}")
            # Consecutive messages start on 4-byte boundaries, so round the
            # length up before moving on.
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
241
242
# Cache of the Generic Netlink families reported by the kernel, keyed by
# family name.  Stays None until _genl_load_families() has been called.
genl_family_name_to_id = None
244
245
246def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
247    # we prepend length in _genl_msg_finalize()
248    if seq is None:
249        seq = random.randint(1, 1024)
250    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
251    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
252    return nlmsg + genlmsg
253
254
255def _genl_msg_finalize(msg):
256    return struct.pack("I", len(msg) + 4) + msg
257
258
def _genl_parse_mcast_groups(groups_attr):
    """Decode a CTRL_ATTR_MCAST_GROUPS attribute into {group name: id}."""
    groups = dict()
    for entry in NlAttrs(groups_attr.raw):
        grp_name = None
        grp_id = None
        for field in NlAttrs(entry.raw):
            if field.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                grp_name = field.as_strz()
            elif field.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                grp_id = field.as_scalar('u32')
        # Entries lacking either part are dropped
        if grp_name and grp_id is not None:
            groups[grp_name] = grp_id
    return groups


def _genl_parse_family(gm):
    """Collect the attributes of one GETFAMILY reply into a dict."""
    fam = dict()
    for attr in NlAttrs(gm.raw):
        if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
            fam['id'] = attr.as_scalar('u16')
        elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
            fam['name'] = attr.as_strz()
        elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
            fam['maxattr'] = attr.as_scalar('u32')
        elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
            fam['mcast'] = _genl_parse_mcast_groups(attr)
    return fam


def _genl_load_families():
    """Dump the kernel's Generic Netlink families into the module cache.

    Opens a temporary genetlink socket, sends a GETFAMILY dump request and
    stores every family that has both a name and an id in the global
    genl_family_name_to_id dict.  The function returns once a done or an
    error message is received; an error is only printed, which leaves the
    cache with whatever was collected up to that point.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP
        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL, flags, Netlink.CTRL_CMD_GETFAMILY, 1))
        sock.send(request, 0)

        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            for nl_msg in NlMsgs(reply):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                fam = _genl_parse_family(GenlMsg(nl_msg))
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam
306
307
class GenlMsg:
    """Generic Netlink view of an already parsed Netlink message.

    Unpacks the 4-byte genetlink header (command, version, reserved) from
    the start of ``nl_msg.raw`` and keeps what follows in ``raw``.  The
    wrapped message stays reachable as ``nl``.
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the genetlink command number."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is attached from outside (NetlinkProtocol.decode()), so
        # it may not exist yet; printing such a message must not fail.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg
323
324
class NetlinkProtocol:
    """Message framing for a raw Netlink family.

    Holds the family name and the Netlink protocol number, builds request
    headers and turns received NlMsg objects into decoded messages.
    """

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack type, flags, sequence number and a zero port id.

        The length field is not included.  A random sequence number in
        1..1024 is used when ``seq`` is None.
        """
        if seq is None:
            seq = random.randint(1, 1024)
        return struct.pack("HHII", nl_type, nl_flags, seq, 0)

    def message(self, flags, command, version, seq=None):
        """Build a request header; the command is used as message type.

        ``version`` is accepted but not used here.
        """
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        """Return the protocol specific view of ``nl_msg`` (itself here)."""
        return nl_msg

    def decode(self, ynl, nl_msg):
        """Attach the parsed attributes to the message as ``raw_attrs``.

        When ``ynl`` is given, the attributes are looked for after the
        fixed header of the operation matching the message's command.
        """
        msg = self._decode(nl_msg)
        if ynl:
            op = ynl.rsp_by_value[msg.cmd()]
            attrs_offset = ynl._struct_size(op.fixed_header)
        else:
            attrs_offset = 0
        msg.raw_attrs = NlAttrs(msg.raw, attrs_offset)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the value of the named multicast group from the spec.

        Raises Exception when ``mcast_groups`` has no such group.
        """
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        """Return the size in bytes of the protocol's message header."""
        return 16
358
359
class GenlProtocol(NetlinkProtocol):
    """Message framing for a Generic Netlink family.

    The family is looked up in the module-wide family cache, which is
    loaded on first use.  A KeyError is raised when the cache has no entry
    for ``family_name``.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        if genl_family_name_to_id is None:
            _genl_load_families()

        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']

    def message(self, flags, command, version, seq=None):
        """Build a request header followed by the genetlink header."""
        header = self._message(self.family_id, flags, seq)
        return header + struct.pack("BBH", command, version, 0)

    def _decode(self, nl_msg):
        """Wrap the message to expose its genetlink header."""
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the id of the named group as reported for the family.

        ``mcast_groups`` is ignored.  Raises Exception when the family has
        no such group.
        """
        family_groups = self.genl_family['mcast']
        if mcast_name in family_groups:
            return family_groups[mcast_name]
        raise Exception(f'Multicast group "{mcast_name}" not present in the family')

    def msghdr_size(self):
        """Return the Netlink header size plus the 4-byte genetlink header."""
        return super().msghdr_size() + 4
386
387
class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes used for name lookup.

    The innermost scope comes first, followed by the scopes of ``outer``
    (if any), so nested attribute sets can see values of enclosing ones.
    """

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        self.scopes = [self.SpecValuesPair(attr_space, attrs)]
        if outer:
            self.scopes.extend(outer.scopes)

    def lookup(self, name):
        """Return the value of ``name`` from the first scope defining it.

        Raises Exception when the defining scope holds no value for the
        name, or when no scope defines the name at all.
        """
        for spec, values in self.scopes:
            if name not in spec:
                continue
            if name in values:
                return values[name]
            spec_name = spec.yaml['name']
            raise Exception(
                f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")
405
406
407#
408# YNL implementation details.
409#
410
411
412class YnlFamily(SpecFamily):
413    def __init__(self, def_path, schema=None, process_unknown=False,
414                 recv_size=0):
415        super().__init__(def_path, schema)
416
417        self.include_raw = False
418        self.process_unknown = process_unknown
419
420        try:
421            if self.proto == "netlink-raw":
422                self.nlproto = NetlinkProtocol(self.yaml['name'],
423                                               self.yaml['protonum'])
424            else:
425                self.nlproto = GenlProtocol(self.yaml['name'])
426        except KeyError:
427            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
428
429        self._recv_dbg = False
430        # Note that netlink will use conservative (min) message size for
431        # the first dump recv() on the socket, our setting will only matter
432        # from the second recv() on.
433        self._recv_size = recv_size if recv_size else 131072
434        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
435        # for a message, so smaller receive sizes will lead to truncation.
436        # Note that the min size for other families may be larger than 4k!
437        if self._recv_size < 4000:
438            raise ConfigError()
439
440        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
441        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
442        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
443        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
444
445        self.async_msg_ids = set()
446        self.async_msg_queue = []
447
448        for msg in self.msgs.values():
449            if msg.is_async:
450                self.async_msg_ids.add(msg.rsp_value)
451
452        for op_name, op in self.ops.items():
453            bound_f = functools.partial(self._op, op_name)
454            setattr(self, op.ident_name, bound_f)
455
456
457    def ntf_subscribe(self, mcast_name):
458        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
459        self.sock.bind((0, 0))
460        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
461                             mcast_id)
462
463    def set_recv_dbg(self, enabled):
464        self._recv_dbg = enabled
465
466    def _recv_dbg_print(self, reply, nl_msgs):
467        if not self._recv_dbg:
468            return
469        print("Recv: read", len(reply), "bytes,",
470              len(nl_msgs.msgs), "messages", file=sys.stderr)
471        for nl_msg in nl_msgs:
472            print("  ", nl_msg, file=sys.stderr)
473
474    def _encode_enum(self, attr_spec, value):
475        enum = self.consts[attr_spec['enum']]
476        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
477            scalar = 0
478            if isinstance(value, str):
479                value = [value]
480            for single_value in value:
481                scalar += enum.entries[single_value].user_value(as_flags = True)
482            return scalar
483        else:
484            return enum.entries[value].user_value()
485
486    def _get_scalar(self, attr_spec, value):
487        try:
488            return int(value)
489        except (ValueError, TypeError) as e:
490            if 'enum' not in attr_spec:
491                raise e
492        return self._encode_enum(attr_spec, value)
493
494    def _add_attr(self, space, name, value, search_attrs):
495        try:
496            attr = self.attr_sets[space][name]
497        except KeyError:
498            raise Exception(f"Space '{space}' has no attribute '{name}'")
499        nl_type = attr.value
500
501        if attr.is_multi and isinstance(value, list):
502            attr_payload = b''
503            for subvalue in value:
504                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
505            return attr_payload
506
507        if attr["type"] == 'nest':
508            nl_type |= Netlink.NLA_F_NESTED
509            attr_payload = b''
510            sub_attrs = SpaceAttrs(self.attr_sets[space], value, search_attrs)
511            for subname, subvalue in value.items():
512                attr_payload += self._add_attr(attr['nested-attributes'],
513                                               subname, subvalue, sub_attrs)
514        elif attr["type"] == 'flag':
515            if not value:
516                # If value is absent or false then skip attribute creation.
517                return b''
518            attr_payload = b''
519        elif attr["type"] == 'string':
520            attr_payload = str(value).encode('ascii') + b'\x00'
521        elif attr["type"] == 'binary':
522            if isinstance(value, bytes):
523                attr_payload = value
524            elif isinstance(value, str):
525                attr_payload = bytes.fromhex(value)
526            elif isinstance(value, dict) and attr.struct_name:
527                attr_payload = self._encode_struct(attr.struct_name, value)
528            else:
529                raise Exception(f'Unknown type for binary attribute, value: {value}')
530        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
531            scalar = self._get_scalar(attr, value)
532            if attr.is_auto_scalar:
533                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
534            else:
535                attr_type = attr["type"]
536            format = NlAttr.get_format(attr_type, attr.byte_order)
537            attr_payload = format.pack(scalar)
538        elif attr['type'] in "bitfield32":
539            scalar_value = self._get_scalar(attr, value["value"])
540            scalar_selector = self._get_scalar(attr, value["selector"])
541            attr_payload = struct.pack("II", scalar_value, scalar_selector)
542        elif attr['type'] == 'sub-message':
543            msg_format = self._resolve_selector(attr, search_attrs)
544            attr_payload = b''
545            if msg_format.fixed_header:
546                attr_payload += self._encode_struct(msg_format.fixed_header, value)
547            if msg_format.attr_set:
548                if msg_format.attr_set in self.attr_sets:
549                    nl_type |= Netlink.NLA_F_NESTED
550                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
551                    for subname, subvalue in value.items():
552                        attr_payload += self._add_attr(msg_format.attr_set,
553                                                       subname, subvalue, sub_attrs)
554                else:
555                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
556        else:
557            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
558
559        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
560        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
561
562    def _decode_enum(self, raw, attr_spec):
563        enum = self.consts[attr_spec['enum']]
564        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
565            i = 0
566            value = set()
567            while raw:
568                if raw & 1:
569                    value.add(enum.entries_by_val[i].name)
570                raw >>= 1
571                i += 1
572        else:
573            value = enum.entries_by_val[raw].name
574        return value
575
576    def _decode_binary(self, attr, attr_spec):
577        if attr_spec.struct_name:
578            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
579        elif attr_spec.sub_type:
580            decoded = attr.as_c_array(attr_spec.sub_type)
581        else:
582            decoded = attr.as_bin()
583            if attr_spec.display_hint:
584                decoded = self._formatted_string(decoded, attr_spec.display_hint)
585        return decoded
586
587    def _decode_array_nest(self, attr, attr_spec):
588        decoded = []
589        offset = 0
590        while offset < len(attr.raw):
591            item = NlAttr(attr.raw, offset)
592            offset += item.full_len
593
594            subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
595            decoded.append({ item.type: subattrs })
596        return decoded
597
598    def _decode_nest_type_value(self, attr, attr_spec):
599        decoded = {}
600        value = attr
601        for name in attr_spec['type-value']:
602            value = NlAttr(value.raw, 0)
603            decoded[name] = value.type
604        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
605        decoded.update(subattrs)
606        return decoded
607
608    def _decode_unknown(self, attr):
609        if attr.is_nest:
610            return self._decode(NlAttrs(attr.raw), None)
611        else:
612            return attr.as_bin()
613
614    def _rsp_add(self, rsp, name, is_multi, decoded):
615        if is_multi == None:
616            if name in rsp and type(rsp[name]) is not list:
617                rsp[name] = [rsp[name]]
618                is_multi = True
619            else:
620                is_multi = False
621
622        if not is_multi:
623            rsp[name] = decoded
624        elif name in rsp:
625            rsp[name].append(decoded)
626        else:
627            rsp[name] = [decoded]
628
629    def _resolve_selector(self, attr_spec, search_attrs):
630        sub_msg = attr_spec.sub_message
631        if sub_msg not in self.sub_msgs:
632            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
633        sub_msg_spec = self.sub_msgs[sub_msg]
634
635        selector = attr_spec.selector
636        value = search_attrs.lookup(selector)
637        if value not in sub_msg_spec.formats:
638            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
639
640        spec = sub_msg_spec.formats[value]
641        return spec
642
643    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
644        msg_format = self._resolve_selector(attr_spec, search_attrs)
645        decoded = {}
646        offset = 0
647        if msg_format.fixed_header:
648            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header));
649            offset = self._struct_size(msg_format.fixed_header)
650        if msg_format.attr_set:
651            if msg_format.attr_set in self.attr_sets:
652                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
653                decoded.update(subdict)
654            else:
655                raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")
656        return decoded
657
658    def _decode(self, attrs, space, outer_attrs = None):
659        rsp = dict()
660        if space:
661            attr_space = self.attr_sets[space]
662            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
663
664        for attr in attrs:
665            try:
666                attr_spec = attr_space.attrs_by_val[attr.type]
667            except (KeyError, UnboundLocalError):
668                if not self.process_unknown:
669                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
670                attr_name = f"UnknownAttr({attr.type})"
671                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
672                continue
673
674            if attr_spec["type"] == 'nest':
675                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
676                decoded = subdict
677            elif attr_spec["type"] == 'string':
678                decoded = attr.as_strz()
679            elif attr_spec["type"] == 'binary':
680                decoded = self._decode_binary(attr, attr_spec)
681            elif attr_spec["type"] == 'flag':
682                decoded = True
683            elif attr_spec.is_auto_scalar:
684                decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
685            elif attr_spec["type"] in NlAttr.type_formats:
686                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
687                if 'enum' in attr_spec:
688                    decoded = self._decode_enum(decoded, attr_spec)
689            elif attr_spec["type"] == 'array-nest':
690                decoded = self._decode_array_nest(attr, attr_spec)
691            elif attr_spec["type"] == 'bitfield32':
692                value, selector = struct.unpack("II", attr.raw)
693                if 'enum' in attr_spec:
694                    value = self._decode_enum(value, attr_spec)
695                    selector = self._decode_enum(selector, attr_spec)
696                decoded = {"value": value, "selector": selector}
697            elif attr_spec["type"] == 'sub-message':
698                decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
699            elif attr_spec["type"] == 'nest-type-value':
700                decoded = self._decode_nest_type_value(attr, attr_spec)
701            else:
702                if not self.process_unknown:
703                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
704                decoded = self._decode_unknown(attr)
705
706            self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
707
708        return rsp
709
710    def _decode_extack_path(self, attrs, attr_set, offset, target):
711        for attr in attrs:
712            try:
713                attr_spec = attr_set.attrs_by_val[attr.type]
714            except KeyError:
715                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
716            if offset > target:
717                break
718            if offset == target:
719                return '.' + attr_spec.name
720
721            if offset + attr.full_len <= target:
722                offset += attr.full_len
723                continue
724            if attr_spec['type'] != 'nest':
725                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
726            offset += 4
727            subpath = self._decode_extack_path(NlAttrs(attr.raw),
728                                               self.attr_sets[attr_spec['nested-attributes']],
729                                               offset, target)
730            if subpath is None:
731                return None
732            return '.' + attr_spec.name + subpath
733
734        return None
735
736    def _decode_extack(self, request, op, extack):
737        if 'bad-attr-offs' not in extack:
738            return
739
740        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set))
741        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
742        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
743                                        extack['bad-attr-offs'])
744        if path:
745            del extack['bad-attr-offs']
746            extack['bad-attr'] = path
747
748    def _struct_size(self, name):
749        if name:
750            members = self.consts[name].members
751            size = 0
752            for m in members:
753                if m.type in ['pad', 'binary']:
754                    if m.struct:
755                        size += self._struct_size(m.struct)
756                    else:
757                        size += m.len
758                else:
759                    format = NlAttr.get_format(m.type, m.byte_order)
760                    size += format.size
761            return size
762        else:
763            return 0
764
765    def _decode_struct(self, data, name):
766        members = self.consts[name].members
767        attrs = dict()
768        offset = 0
769        for m in members:
770            value = None
771            if m.type == 'pad':
772                offset += m.len
773            elif m.type == 'binary':
774                if m.struct:
775                    len = self._struct_size(m.struct)
776                    value = self._decode_struct(data[offset : offset + len],
777                                                m.struct)
778                    offset += len
779                else:
780                    value = data[offset : offset + m.len]
781                    offset += m.len
782            else:
783                format = NlAttr.get_format(m.type, m.byte_order)
784                [ value ] = format.unpack_from(data, offset)
785                offset += format.size
786            if value is not None:
787                if m.enum:
788                    value = self._decode_enum(value, m)
789                elif m.display_hint:
790                    value = self._formatted_string(value, m.display_hint)
791                attrs[m.name] = value
792        return attrs
793
794    def _encode_struct(self, name, vals):
795        members = self.consts[name].members
796        attr_payload = b''
797        for m in members:
798            value = vals.pop(m.name) if m.name in vals else None
799            if m.type == 'pad':
800                attr_payload += bytearray(m.len)
801            elif m.type == 'binary':
802                if m.struct:
803                    if value is None:
804                        value = dict()
805                    attr_payload += self._encode_struct(m.struct, value)
806                else:
807                    if value is None:
808                        attr_payload += bytearray(m.len)
809                    else:
810                        attr_payload += bytes.fromhex(value)
811            else:
812                if value is None:
813                    value = 0
814                format = NlAttr.get_format(m.type, m.byte_order)
815                attr_payload += format.pack(value)
816        return attr_payload
817
818    def _formatted_string(self, raw, display_hint):
819        if display_hint == 'mac':
820            formatted = ':'.join('%02x' % b for b in raw)
821        elif display_hint == 'hex':
822            formatted = bytes.hex(raw, ' ')
823        elif display_hint in [ 'ipv4', 'ipv6' ]:
824            formatted = format(ipaddress.ip_address(raw))
825        elif display_hint == 'uuid':
826            formatted = str(uuid.UUID(bytes=raw))
827        else:
828            formatted = raw
829        return formatted
830
831    def handle_ntf(self, decoded):
832        msg = dict()
833        if self.include_raw:
834            msg['raw'] = decoded
835        op = self.rsp_by_value[decoded.cmd()]
836        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
837        if op.fixed_header:
838            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
839
840        msg['name'] = op['name']
841        msg['msg'] = attrs
842        self.async_msg_queue.append(msg)
843
844    def check_ntf(self):
845        while True:
846            try:
847                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
848            except BlockingIOError:
849                return
850
851            nms = NlMsgs(reply)
852            self._recv_dbg_print(reply, nms)
853            for nl_msg in nms:
854                if nl_msg.error:
855                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
856                    print(nl_msg)
857                    continue
858                if nl_msg.done:
859                    print("Netlink done while checking for ntf!?")
860                    continue
861
862                decoded = self.nlproto.decode(self, nl_msg)
863                if decoded.cmd() not in self.async_msg_ids:
864                    print("Unexpected msg id done while checking for ntf", decoded)
865                    continue
866
867                self.handle_ntf(decoded)
868
869    def operation_do_attributes(self, name):
870      """
871      For a given operation name, find and return a supported
872      set of attributes (as a dict).
873      """
874      op = self.find_operation(name)
875      if not op:
876        return None
877
878      return op['do']['request']['attributes'].copy()
879
880    def _op(self, method, vals, flags=None, dump=False):
881        op = self.ops[method]
882
883        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
884        for flag in flags or []:
885            nl_flags |= flag
886        if dump:
887            nl_flags |= Netlink.NLM_F_DUMP
888
889        req_seq = random.randint(1024, 65535)
890        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
891        if op.fixed_header:
892            msg += self._encode_struct(op.fixed_header, vals)
893        search_attrs = SpaceAttrs(op.attr_set, vals)
894        for name, value in vals.items():
895            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
896        msg = _genl_msg_finalize(msg)
897
898        self.sock.send(msg, 0)
899
900        done = False
901        rsp = []
902        while not done:
903            reply = self.sock.recv(self._recv_size)
904            nms = NlMsgs(reply, attr_space=op.attr_set)
905            self._recv_dbg_print(reply, nms)
906            for nl_msg in nms:
907                if nl_msg.extack:
908                    self._decode_extack(msg, op, nl_msg.extack)
909
910                if nl_msg.error:
911                    raise NlError(nl_msg)
912                if nl_msg.done:
913                    if nl_msg.extack:
914                        print("Netlink warning:")
915                        print(nl_msg)
916                    done = True
917                    break
918
919                decoded = self.nlproto.decode(self, nl_msg)
920
921                # Check if this is a reply to our request
922                if nl_msg.nl_seq != req_seq or decoded.cmd() != op.rsp_value:
923                    if decoded.cmd() in self.async_msg_ids:
924                        self.handle_ntf(decoded)
925                        continue
926                    else:
927                        print('Unexpected message: ' + repr(decoded))
928                        continue
929
930                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
931                if op.fixed_header:
932                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
933                rsp.append(rsp_msg)
934
935        if not rsp:
936            return None
937        if not dump and len(rsp) == 1:
938            return rsp[0]
939        return rsp
940
941    def do(self, method, vals, flags=None):
942        return self._op(method, vals, flags)
943
944    def dump(self, method, vals):
945        return self._op(method, vals, [], dump=True)
946