xref: /linux/tools/net/ynl/lib/ynl.py (revision 2a52ca7c98960aafb0eca9ef96b2d0c932171357)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4from enum import Enum
5import functools
6import os
7import random
8import socket
9import struct
10from struct import Struct
11import sys
12import yaml
13import ipaddress
14import uuid
15
16from .nlspec import SpecFamily
17
18#
19# Generic Netlink code which should really be in some library, but I can't quickly find one.
20#
21
22
class Netlink:
    """Numeric constants from the netlink / generic netlink UAPI used here."""

    # --- Socket-level options (level SOL_NETLINK) ---
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # --- Control message types (nlmsghdr.nlmsg_type) ---
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # --- Message flags (nlmsghdr.nlmsg_flags) ---
    NLM_F_REQUEST = 1 << 0
    NLM_F_ACK = 1 << 2
    NLM_F_ROOT = 1 << 8
    NLM_F_MATCH = 1 << 9

    # Same bit values as NLM_F_ROOT / NLM_F_MATCH above; which meaning
    # applies depends on the kind of message (see linux/netlink.h).
    NLM_F_REPLACE = 1 << 8
    NLM_F_EXCL = 1 << 9
    NLM_F_CREATE = 1 << 10
    NLM_F_APPEND = 1 << 11

    # Flags the kernel sets on NLMSG_ERROR replies.
    NLM_F_CAPPED = 1 << 8
    NLM_F_ACK_TLVS = 1 << 9

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # --- Attribute header (nlattr.nla_type) flag bits ---
    NLA_F_NESTED = 1 << 15
    NLA_F_NET_BYTEORDER = 1 << 14

    # Union of the flag bits; cleared from nla_type to get the bare type.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # --- Generic netlink ---
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl command and attributes
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # --- Extended ACK attribute types ---
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # --- Policy description attribute types ---
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Attribute types reported in policy dumps; members are numbered from 1
    # in this order ('flag' == 1 ... 'uint' == 17).
    AttrType = Enum('AttrType', (
        'flag', 'u8', 'u16', 'u32', 'u64',
        's8', 's16', 's32', 's64',
        'binary', 'string', 'nul-string',
        'nested', 'nested-array',
        'bitfield32', 'sint', 'uint',
    ))
99
class NlError(Exception):
    """Exception carrying a netlink error reply.

    ``nl_msg`` is the message that reported the error. ``error`` is the
    negation of ``nl_msg.error``, i.e. a positive errno value suitable for
    os.strerror().
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        errno_value = -nl_msg.error
        self.error = errno_value

    def __str__(self):
        description = os.strerror(self.error)
        return f"Netlink error: {description}\n{self.nl_msg}"
107
108
class ConfigError(Exception):
    """Raised when a YnlFamily is given an unusable configuration value."""
111
112
class NlAttr:
    """A single netlink attribute (nlattr header + payload) read from a buffer.

    Attributes set by the constructor:
        type:        attribute type with the NLA_F_NESTED / NLA_F_NET_BYTEORDER
                     bits cleared.
        is_nest:     non-zero when NLA_F_NESTED was set.
        payload_len: nla_len as received (4-byte header included).
        full_len:    payload_len rounded up to a multiple of 4, i.e. the
                     distance to the next attribute.
        raw:         payload bytes (header stripped, padding excluded).
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        if self._len < 4:
            # nla_len includes the 4-byte header. A smaller value would give
            # full_len == 0 and make NlAttrs spin forever at the same offset.
            raise Exception(f"Malformed netlink attribute at offset {offset}: "
                            f"nla_len {self._len} is shorter than the 4-byte header")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct used for scalar ``attr_type`` ('u8' .. 's64').

        Native byte order is used when ``byte_order`` is falsy; otherwise
        "big-endian" selects big endian and any other value little endian.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the payload as one scalar of ``attr_type``."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a variable-width 'sint'/'uint' payload (4 or 8 bytes).

        Only the first letter of ``attr_type`` ('s' or 'u') is used; the
        width comes from the payload length.
        """
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII, dropping its last byte (the NUL)."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes unchanged."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as a packed array of native-order ``type`` scalars."""
        fmt = self.get_format(type)
        return [x[0] for x in fmt.iter_unpack(self.raw)]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
165
166
class NlAttrs:
    """All netlink attributes found in ``msg`` from byte ``offset`` to the end."""

    def __init__(self, msg, offset=0):
        parsed = []
        pos = offset
        while pos < len(msg):
            nla = NlAttr(msg, pos)
            parsed.append(nla)
            pos += nla.full_len
        self.attrs = parsed

    def __iter__(self):
        return iter(self.attrs)

    def __repr__(self):
        # One attribute per line.
        return '\n'.join(repr(nla) for nla in self.attrs)
186
187
class NlMsg:
    """One netlink message parsed from ``msg`` starting at byte ``offset``.

    Exposes the nlmsghdr fields (nl_len, nl_type, nl_flags, nl_seq,
    nl_portid) and the payload bytes in ``raw``. For NLMSG_ERROR and
    NLMSG_DONE, ``error`` holds the status read from the payload and
    ``done`` is set to 1. When NLM_F_ACK_TLVS is set, the extended-ACK TLVs
    are decoded into the ``extack`` dict (otherwise ``extack`` is None).
    If ``attr_space`` is given, a top-level 'miss-type' number is replaced
    by the attribute's spec name (and its doc, when present).
    """
    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # Capped reply: 4-byte error code + the request's 16-byte nlmsghdr.
            extack_off = 20
            if not (self.nl_flags & Netlink.NLM_F_CAPPED) and len(self.raw) >= 8:
                # Uncapped reply: the request's payload was echoed as well.
                # The echoed nlmsg_len (raw[4:8]) covers header + payload, and
                # the TLVs start at the next 4-byte boundary after it.
                req_len = struct.unpack_from("I", self.raw, 4)[0]
                extack_off = 4 + ((req_len + 3) & ~3)
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        self.extack['miss-type'] = spec['name']
                        if 'doc' in spec:
                            self.extack['miss-type-doc'] = spec['doc']

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY nest into a dict of policy limits."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_id = attr.as_scalar('u32')
                policy['type'] = Netlink.AttrType(type_id).name
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def cmd(self):
        """Return the message type (nlmsg_type), used as the command for raw netlink."""
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg
274
275
class NlMsgs:
    """Every netlink message packed into one receive buffer, in order."""

    def __init__(self, data, attr_space=None):
        self.msgs = []
        cursor = 0
        while cursor < len(data):
            nl_msg = NlMsg(data, cursor, attr_space=attr_space)
            self.msgs.append(nl_msg)
            cursor += nl_msg.nl_len

    def __iter__(self):
        return iter(self.msgs)
288
289
# Generic netlink family cache, keyed by family name. Each value is a dict
# with 'id' and 'name', plus 'maxattr' / 'mcast' when the kernel reported
# them. Stays None until _genl_load_families() fills it in.
genl_family_name_to_id = None
291
292
293def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
294    # we prepend length in _genl_msg_finalize()
295    if seq is None:
296        seq = random.randint(1, 1024)
297    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
298    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
299    return nlmsg + genlmsg
300
301
302def _genl_msg_finalize(msg):
303    return struct.pack("I", len(msg) + 4) + msg
304
305
def _genl_load_families():
    """Fill genl_family_name_to_id by dumping the kernel's generic netlink families.

    Sends a CTRL_CMD_GETFAMILY dump request on a new NETLINK_GENERIC socket.
    For each reported family that has both a name and an id, stores a dict
    with 'id', 'name' and, when present, 'maxattr' and 'mcast' (multicast
    group name -> group id). If a reply carries an error, the error code is
    printed and loading stops, leaving whatever was collected so far.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        dump_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP
        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL, dump_flags, Netlink.CTRL_CMD_GETFAMILY, 1))
        sock.send(request, 0)

        genl_family_name_to_id = {}

        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                family = {}
                for attr in NlAttrs(GenlMsg(nl_msg).raw):
                    kind = attr.type
                    if kind == Netlink.CTRL_ATTR_FAMILY_ID:
                        family['id'] = attr.as_scalar('u16')
                    elif kind == Netlink.CTRL_ATTR_FAMILY_NAME:
                        family['name'] = attr.as_strz()
                    elif kind == Netlink.CTRL_ATTR_MAXATTR:
                        family['maxattr'] = attr.as_scalar('u32')
                    elif kind == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        groups = {}
                        family['mcast'] = groups
                        # Each entry is a nest holding one group's name and id.
                        for group in NlAttrs(attr.raw):
                            grp_name, grp_id = None, None
                            for grp_attr in NlAttrs(group.raw):
                                if grp_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    grp_name = grp_attr.as_strz()
                                elif grp_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    grp_id = grp_attr.as_scalar('u32')
                            if grp_name and grp_id is not None:
                                groups[grp_name] = grp_id

                if 'name' in family and 'id' in family:
                    genl_family_name_to_id[family['name']] = family
353
354
class GenlMsg:
    """Generic netlink view of an NlMsg: genlmsghdr fields plus attribute bytes.

    ``raw`` is the payload after the 4-byte genlmsghdr (cmd, version,
    reserved). ``raw_attrs`` is only attached later, by
    NetlinkProtocol.decode().
    """
    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the generic netlink command from the genlmsghdr."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs exists only after NetlinkProtocol.decode(). Without this
        # fallback, repr() of a freshly built GenlMsg raised AttributeError.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg
370
371
class NetlinkProtocol:
    """Message framing for a classic (non-generic) netlink protocol family.

    ``proto_num`` is the netlink protocol number the socket is opened with.
    GenlProtocol overrides the framing hooks for generic netlink.
    """

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack an nlmsghdr without nlmsg_len (type, flags, seq, port id 0)."""
        sequence = random.randint(1, 1024) if seq is None else seq
        return struct.pack("HHII", nl_type, nl_flags, sequence, 0)

    def message(self, flags, command, version, seq=None):
        """Build a request header; raw netlink puts the command in nlmsg_type.

        ``version`` is accepted for interface parity and ignored here.
        """
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        # Raw netlink has no extra header to unwrap.
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap ``nl_msg`` and attach its attributes as ``raw_attrs``.

        Attribute parsing starts after the op's fixed header, whose size is
        taken from ``ynl._struct_size``.
        """
        wrapped = self._decode(nl_msg)
        header_len = ynl._struct_size(op.fixed_header)
        wrapped.raw_attrs = NlAttrs(wrapped.raw, header_len)
        return wrapped

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the multicast group id for ``mcast_name`` from the spec's groups."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        """Size in bytes of the protocol's message header (nlmsghdr only)."""
        return 16
402
403
class GenlProtocol(NetlinkProtocol):
    """Generic netlink framing: family id in nlmsg_type plus a genlmsghdr.

    The family info is looked up in the module-level cache, which is loaded
    from the kernel on first use. A KeyError propagates if the kernel did
    not report ``family_name``.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        if genl_family_name_to_id is None:
            _genl_load_families()

        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']

    def message(self, flags, command, version, seq=None):
        """Build nlmsghdr (without length) followed by the genlmsghdr."""
        header = self._message(self.family_id, flags, seq)
        return header + struct.pack("BBH", command, version, 0)

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the group id reported by the kernel (``mcast_groups`` is unused)."""
        groups = self.genl_family['mcast']
        if mcast_name not in groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return groups[mcast_name]

    def msghdr_size(self):
        """nlmsghdr size plus the 4-byte genlmsghdr."""
        return super().msghdr_size() + 4
430
431
class SpaceAttrs:
    """Stack of (attribute-set spec, values) scopes used to resolve selectors.

    The scope passed to the constructor is searched first, followed by the
    scopes of ``outer``, so a selector can refer to an enclosing level.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer=None):
        inherited = outer.scopes if outer else []
        self.scopes = [self.SpecValuesPair(attr_space, attrs), *inherited]

    def lookup(self, name):
        """Return the value of ``name`` from the innermost scope whose spec defines it.

        Raises if that scope has no value for ``name``, or if no scope's
        spec defines it at all.
        """
        owner = next((scope for scope in self.scopes if name in scope.spec), None)
        if owner is None:
            raise Exception(f"Attribute '{name}' not defined in any attribute-set")
        if name not in owner.values:
            spec_name = owner.spec.yaml['name']
            raise Exception(
                f"No value for '{name}' in attribute space '{spec_name}'")
        return owner.values[name]
449
450
451#
452# YNL implementation details.
453#
454
455
456class YnlFamily(SpecFamily):
457    def __init__(self, def_path, schema=None, process_unknown=False,
458                 recv_size=0):
459        super().__init__(def_path, schema)
460
461        self.include_raw = False
462        self.process_unknown = process_unknown
463
464        try:
465            if self.proto == "netlink-raw":
466                self.nlproto = NetlinkProtocol(self.yaml['name'],
467                                               self.yaml['protonum'])
468            else:
469                self.nlproto = GenlProtocol(self.yaml['name'])
470        except KeyError:
471            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
472
473        self._recv_dbg = False
474        # Note that netlink will use conservative (min) message size for
475        # the first dump recv() on the socket, our setting will only matter
476        # from the second recv() on.
477        self._recv_size = recv_size if recv_size else 131072
478        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
479        # for a message, so smaller receive sizes will lead to truncation.
480        # Note that the min size for other families may be larger than 4k!
481        if self._recv_size < 4000:
482            raise ConfigError()
483
484        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
485        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
486        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
487        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
488
489        self.async_msg_ids = set()
490        self.async_msg_queue = []
491
492        for msg in self.msgs.values():
493            if msg.is_async:
494                self.async_msg_ids.add(msg.rsp_value)
495
496        for op_name, op in self.ops.items():
497            bound_f = functools.partial(self._op, op_name)
498            setattr(self, op.ident_name, bound_f)
499
500
501    def ntf_subscribe(self, mcast_name):
502        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
503        self.sock.bind((0, 0))
504        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
505                             mcast_id)
506
507    def set_recv_dbg(self, enabled):
508        self._recv_dbg = enabled
509
510    def _recv_dbg_print(self, reply, nl_msgs):
511        if not self._recv_dbg:
512            return
513        print("Recv: read", len(reply), "bytes,",
514              len(nl_msgs.msgs), "messages", file=sys.stderr)
515        for nl_msg in nl_msgs:
516            print("  ", nl_msg, file=sys.stderr)
517
518    def _encode_enum(self, attr_spec, value):
519        enum = self.consts[attr_spec['enum']]
520        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
521            scalar = 0
522            if isinstance(value, str):
523                value = [value]
524            for single_value in value:
525                scalar += enum.entries[single_value].user_value(as_flags = True)
526            return scalar
527        else:
528            return enum.entries[value].user_value()
529
530    def _get_scalar(self, attr_spec, value):
531        try:
532            return int(value)
533        except (ValueError, TypeError) as e:
534            if 'enum' not in attr_spec:
535                raise e
536        return self._encode_enum(attr_spec, value)
537
538    def _add_attr(self, space, name, value, search_attrs):
539        try:
540            attr = self.attr_sets[space][name]
541        except KeyError:
542            raise Exception(f"Space '{space}' has no attribute '{name}'")
543        nl_type = attr.value
544
545        if attr.is_multi and isinstance(value, list):
546            attr_payload = b''
547            for subvalue in value:
548                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
549            return attr_payload
550
551        if attr["type"] == 'nest':
552            nl_type |= Netlink.NLA_F_NESTED
553            attr_payload = b''
554            sub_attrs = SpaceAttrs(self.attr_sets[space], value, search_attrs)
555            for subname, subvalue in value.items():
556                attr_payload += self._add_attr(attr['nested-attributes'],
557                                               subname, subvalue, sub_attrs)
558        elif attr["type"] == 'flag':
559            if not value:
560                # If value is absent or false then skip attribute creation.
561                return b''
562            attr_payload = b''
563        elif attr["type"] == 'string':
564            attr_payload = str(value).encode('ascii') + b'\x00'
565        elif attr["type"] == 'binary':
566            if isinstance(value, bytes):
567                attr_payload = value
568            elif isinstance(value, str):
569                attr_payload = bytes.fromhex(value)
570            elif isinstance(value, dict) and attr.struct_name:
571                attr_payload = self._encode_struct(attr.struct_name, value)
572            else:
573                raise Exception(f'Unknown type for binary attribute, value: {value}')
574        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
575            scalar = self._get_scalar(attr, value)
576            if attr.is_auto_scalar:
577                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
578            else:
579                attr_type = attr["type"]
580            format = NlAttr.get_format(attr_type, attr.byte_order)
581            attr_payload = format.pack(scalar)
582        elif attr['type'] in "bitfield32":
583            scalar_value = self._get_scalar(attr, value["value"])
584            scalar_selector = self._get_scalar(attr, value["selector"])
585            attr_payload = struct.pack("II", scalar_value, scalar_selector)
586        elif attr['type'] == 'sub-message':
587            msg_format = self._resolve_selector(attr, search_attrs)
588            attr_payload = b''
589            if msg_format.fixed_header:
590                attr_payload += self._encode_struct(msg_format.fixed_header, value)
591            if msg_format.attr_set:
592                if msg_format.attr_set in self.attr_sets:
593                    nl_type |= Netlink.NLA_F_NESTED
594                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
595                    for subname, subvalue in value.items():
596                        attr_payload += self._add_attr(msg_format.attr_set,
597                                                       subname, subvalue, sub_attrs)
598                else:
599                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
600        else:
601            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
602
603        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
604        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
605
606    def _decode_enum(self, raw, attr_spec):
607        enum = self.consts[attr_spec['enum']]
608        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
609            i = 0
610            value = set()
611            while raw:
612                if raw & 1:
613                    value.add(enum.entries_by_val[i].name)
614                raw >>= 1
615                i += 1
616        else:
617            value = enum.entries_by_val[raw].name
618        return value
619
620    def _decode_binary(self, attr, attr_spec):
621        if attr_spec.struct_name:
622            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
623        elif attr_spec.sub_type:
624            decoded = attr.as_c_array(attr_spec.sub_type)
625        else:
626            decoded = attr.as_bin()
627            if attr_spec.display_hint:
628                decoded = self._formatted_string(decoded, attr_spec.display_hint)
629        return decoded
630
631    def _decode_array_attr(self, attr, attr_spec):
632        decoded = []
633        offset = 0
634        while offset < len(attr.raw):
635            item = NlAttr(attr.raw, offset)
636            offset += item.full_len
637
638            if attr_spec["sub-type"] == 'nest':
639                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
640                decoded.append({ item.type: subattrs })
641            elif attr_spec["sub-type"] == 'binary':
642                subattrs = item.as_bin()
643                if attr_spec.display_hint:
644                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
645                decoded.append(subattrs)
646            elif attr_spec["sub-type"] in NlAttr.type_formats:
647                subattrs = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
648                if attr_spec.display_hint:
649                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
650                decoded.append(subattrs)
651            else:
652                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
653        return decoded
654
655    def _decode_nest_type_value(self, attr, attr_spec):
656        decoded = {}
657        value = attr
658        for name in attr_spec['type-value']:
659            value = NlAttr(value.raw, 0)
660            decoded[name] = value.type
661        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
662        decoded.update(subattrs)
663        return decoded
664
665    def _decode_unknown(self, attr):
666        if attr.is_nest:
667            return self._decode(NlAttrs(attr.raw), None)
668        else:
669            return attr.as_bin()
670
671    def _rsp_add(self, rsp, name, is_multi, decoded):
672        if is_multi == None:
673            if name in rsp and type(rsp[name]) is not list:
674                rsp[name] = [rsp[name]]
675                is_multi = True
676            else:
677                is_multi = False
678
679        if not is_multi:
680            rsp[name] = decoded
681        elif name in rsp:
682            rsp[name].append(decoded)
683        else:
684            rsp[name] = [decoded]
685
686    def _resolve_selector(self, attr_spec, search_attrs):
687        sub_msg = attr_spec.sub_message
688        if sub_msg not in self.sub_msgs:
689            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
690        sub_msg_spec = self.sub_msgs[sub_msg]
691
692        selector = attr_spec.selector
693        value = search_attrs.lookup(selector)
694        if value not in sub_msg_spec.formats:
695            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
696
697        spec = sub_msg_spec.formats[value]
698        return spec
699
700    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
701        msg_format = self._resolve_selector(attr_spec, search_attrs)
702        decoded = {}
703        offset = 0
704        if msg_format.fixed_header:
705            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header));
706            offset = self._struct_size(msg_format.fixed_header)
707        if msg_format.attr_set:
708            if msg_format.attr_set in self.attr_sets:
709                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
710                decoded.update(subdict)
711            else:
712                raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")
713        return decoded
714
715    def _decode(self, attrs, space, outer_attrs = None):
716        rsp = dict()
717        if space:
718            attr_space = self.attr_sets[space]
719            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
720
721        for attr in attrs:
722            try:
723                attr_spec = attr_space.attrs_by_val[attr.type]
724            except (KeyError, UnboundLocalError):
725                if not self.process_unknown:
726                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
727                attr_name = f"UnknownAttr({attr.type})"
728                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
729                continue
730
731            if attr_spec["type"] == 'nest':
732                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
733                decoded = subdict
734            elif attr_spec["type"] == 'string':
735                decoded = attr.as_strz()
736            elif attr_spec["type"] == 'binary':
737                decoded = self._decode_binary(attr, attr_spec)
738            elif attr_spec["type"] == 'flag':
739                decoded = True
740            elif attr_spec.is_auto_scalar:
741                decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
742            elif attr_spec["type"] in NlAttr.type_formats:
743                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
744                if 'enum' in attr_spec:
745                    decoded = self._decode_enum(decoded, attr_spec)
746            elif attr_spec["type"] == 'indexed-array':
747                decoded = self._decode_array_attr(attr, attr_spec)
748            elif attr_spec["type"] == 'bitfield32':
749                value, selector = struct.unpack("II", attr.raw)
750                if 'enum' in attr_spec:
751                    value = self._decode_enum(value, attr_spec)
752                    selector = self._decode_enum(selector, attr_spec)
753                decoded = {"value": value, "selector": selector}
754            elif attr_spec["type"] == 'sub-message':
755                decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
756            elif attr_spec["type"] == 'nest-type-value':
757                decoded = self._decode_nest_type_value(attr, attr_spec)
758            else:
759                if not self.process_unknown:
760                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
761                decoded = self._decode_unknown(attr)
762
763            self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
764
765        return rsp
766
    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Translate an extack byte offset into a dotted attribute path.

        Walks ``attrs`` (attributes of ``attr_set``, the first of which
        starts at byte ``offset`` of the original request) looking for the
        attribute that starts exactly at byte ``target``. It descends into
        any nest that contains ``target``.

        Returns:
            A path such as ".attr" or ".nest.attr", or None when no
            attribute starts exactly at ``target``.

        Raises:
            Exception: an attribute is not in ``attr_set``, or ``target``
                falls inside an attribute that is not a nest.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            # Already past the target without landing on it exactly.
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            # Target lies beyond this attribute: skip over it.
            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            # Target is inside this attribute, so it must be a nest to descend.
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Step over the 4-byte nlattr header to the first nested attribute.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None
792
793    def _decode_extack(self, request, op, extack):
794        if 'bad-attr-offs' not in extack:
795            return
796
797        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
798        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
799        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
800                                        extack['bad-attr-offs'])
801        if path:
802            del extack['bad-attr-offs']
803            extack['bad-attr'] = path
804
805    def _struct_size(self, name):
806        if name:
807            members = self.consts[name].members
808            size = 0
809            for m in members:
810                if m.type in ['pad', 'binary']:
811                    if m.struct:
812                        size += self._struct_size(m.struct)
813                    else:
814                        size += m.len
815                else:
816                    format = NlAttr.get_format(m.type, m.byte_order)
817                    size += format.size
818            return size
819        else:
820            return 0
821
822    def _decode_struct(self, data, name):
823        members = self.consts[name].members
824        attrs = dict()
825        offset = 0
826        for m in members:
827            value = None
828            if m.type == 'pad':
829                offset += m.len
830            elif m.type == 'binary':
831                if m.struct:
832                    len = self._struct_size(m.struct)
833                    value = self._decode_struct(data[offset : offset + len],
834                                                m.struct)
835                    offset += len
836                else:
837                    value = data[offset : offset + m.len]
838                    offset += m.len
839            else:
840                format = NlAttr.get_format(m.type, m.byte_order)
841                [ value ] = format.unpack_from(data, offset)
842                offset += format.size
843            if value is not None:
844                if m.enum:
845                    value = self._decode_enum(value, m)
846                elif m.display_hint:
847                    value = self._formatted_string(value, m.display_hint)
848                attrs[m.name] = value
849        return attrs
850
851    def _encode_struct(self, name, vals):
852        members = self.consts[name].members
853        attr_payload = b''
854        for m in members:
855            value = vals.pop(m.name) if m.name in vals else None
856            if m.type == 'pad':
857                attr_payload += bytearray(m.len)
858            elif m.type == 'binary':
859                if m.struct:
860                    if value is None:
861                        value = dict()
862                    attr_payload += self._encode_struct(m.struct, value)
863                else:
864                    if value is None:
865                        attr_payload += bytearray(m.len)
866                    else:
867                        attr_payload += bytes.fromhex(value)
868            else:
869                if value is None:
870                    value = 0
871                format = NlAttr.get_format(m.type, m.byte_order)
872                attr_payload += format.pack(value)
873        return attr_payload
874
875    def _formatted_string(self, raw, display_hint):
876        if display_hint == 'mac':
877            formatted = ':'.join('%02x' % b for b in raw)
878        elif display_hint == 'hex':
879            if isinstance(raw, int):
880                formatted = hex(raw)
881            else:
882                formatted = bytes.hex(raw, ' ')
883        elif display_hint in [ 'ipv4', 'ipv6' ]:
884            formatted = format(ipaddress.ip_address(raw))
885        elif display_hint == 'uuid':
886            formatted = str(uuid.UUID(bytes=raw))
887        else:
888            formatted = raw
889        return formatted
890
891    def handle_ntf(self, decoded):
892        msg = dict()
893        if self.include_raw:
894            msg['raw'] = decoded
895        op = self.rsp_by_value[decoded.cmd()]
896        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
897        if op.fixed_header:
898            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
899
900        msg['name'] = op['name']
901        msg['msg'] = attrs
902        self.async_msg_queue.append(msg)
903
904    def check_ntf(self):
905        while True:
906            try:
907                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
908            except BlockingIOError:
909                return
910
911            nms = NlMsgs(reply)
912            self._recv_dbg_print(reply, nms)
913            for nl_msg in nms:
914                if nl_msg.error:
915                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
916                    print(nl_msg)
917                    continue
918                if nl_msg.done:
919                    print("Netlink done while checking for ntf!?")
920                    continue
921
922                op = self.rsp_by_value[nl_msg.cmd()]
923                decoded = self.nlproto.decode(self, nl_msg, op)
924                if decoded.cmd() not in self.async_msg_ids:
925                    print("Unexpected msg id done while checking for ntf", decoded)
926                    continue
927
928                self.handle_ntf(decoded)
929
930    def operation_do_attributes(self, name):
931      """
932      For a given operation name, find and return a supported
933      set of attributes (as a dict).
934      """
935      op = self.find_operation(name)
936      if not op:
937        return None
938
939      return op['do']['request']['attributes'].copy()
940
941    def _encode_message(self, op, vals, flags, req_seq):
942        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
943        for flag in flags or []:
944            nl_flags |= flag
945
946        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
947        if op.fixed_header:
948            msg += self._encode_struct(op.fixed_header, vals)
949        search_attrs = SpaceAttrs(op.attr_set, vals)
950        for name, value in vals.items():
951            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
952        msg = _genl_msg_finalize(msg)
953        return msg
954
    def _ops(self, ops):
        """
        Send a batch of requests and collect their replies.

        @ops is a list of (method, vals, flags) tuples.  The requests are
        encoded with consecutive sequence numbers, starting from a random
        value in [1024, 65535], and sent with a single send() call.  Replies
        are then read until a message with nl_msg.done has been seen for
        every request sent.

        Returns a list with one entry per request, appended in the order the
        done messages arrive.  Requests whose flags list contains
        NLM_F_DUMP get the list of replies.  Other requests get None (no
        replies), the single reply, or a list when several replies arrived.

        Raises NlError for any received message with an error set.  Messages
        that are not replies to our requests are passed to handle_ntf() when
        their command is in async_msg_ids, and are otherwise printed.
        """
        reqs_by_seq = {}
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            # Remember what was sent under each sequence number so replies,
            # extacks and done messages can be matched to their request.
            reqs_by_seq[req_seq] = (op, msg, flags)
            payload += msg
            req_seq += 1

        self.sock.send(payload, 0)

        done = False
        rsp = []
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            # NOTE(review): `op` here is whatever was bound last (the final
            # request encoded above, or the op of a previously handled
            # message); confirm one attr_space is right for the whole buffer.
            nms = NlMsgs(reply, attr_space=op.attr_set)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    if nl_msg.extack:
                        # Resolve a raw bad-attr offset against the request we sent.
                        self._decode_extack(req_msg, op, nl_msg.extack)
                else:
                    # Not one of our sequence numbers: pick the op by command.
                    # NOTE(review): KeyError if cmd() is not in rsp_by_value.
                    op = self.rsp_by_value[nl_msg.cmd()]
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    # Shape the result for the request that just completed.
                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    # NOTE(review): KeyError if a done message carries a
                    # sequence number we did not send or already completed.
                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    # Leaves the for-loop: any messages after this one in the
                    # same recv() buffer are not examined. NOTE(review): confirm
                    # nothing can follow a done message in one buffer.
                    break

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp
1022
1023    def _op(self, method, vals, flags=None, dump=False):
1024        req_flags = flags or []
1025        if dump:
1026            req_flags.append(Netlink.NLM_F_DUMP)
1027
1028        ops = [(method, vals, req_flags)]
1029        return self._ops(ops)[0]
1030
1031    def do(self, method, vals, flags=None):
1032        return self._op(method, vals, flags)
1033
1034    def dump(self, method, vals):
1035        return self._op(method, vals, dump=True)
1036
1037    def do_multi(self, ops):
1038        return self._ops(ops)
1039