xref: /linux/tools/net/ynl/lib/ynl.py (revision 1a371190a375f98c9b106f758ea41558c3f92556)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4from enum import Enum
5import functools
6import os
7import random
8import socket
9import struct
10from struct import Struct
11import sys
12import yaml
13import ipaddress
14import uuid
15
16from .nlspec import SpecFamily
17
18#
19# Generic Netlink code which should really be in some library, but I can't quickly find one.
20#
21
22
class Netlink:
    """Namespace for the numeric netlink / generic netlink constants used here.

    The class is never instantiated; the rest of this file reads the values
    as ``Netlink.NAME``.
    """

    # Socket level and socket options passed to setsockopt()
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Message types that NlMsg treats specially
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # Message flags. Note that the three groups below REQUEST/ACK reuse the
    # same bit values (0x100, 0x200, ...), so a flag word has to be read
    # with the kind of message in mind.
    NLM_F_REQUEST = 0x001
    NLM_F_ACK = 0x004

    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    # Flag bits carried in the attribute type field. NLA_TYPE_MASK holds the
    # flag bits themselves, so NlAttr strips them with "& ~NLA_TYPE_MASK".
    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 16

    # nlctrl: command and attribute ids used when listing families
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack attribute types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy attribute types (nested inside NLMSGERR_ATTR_POLICY)
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Attribute type names reported in a policy; the functional Enum API
    # numbers the members in list order starting from 1.
    AttrType = Enum('AttrType', [
        'flag',
        'u8', 'u16', 'u32', 'u64',
        's8', 's16', 's32', 's64',
        'binary', 'string', 'nul-string',
        'nested', 'nested-array',
        'bitfield32',
        'sint', 'uint',
    ])
99
class NlError(Exception):
    """Error reported by the kernel in reply to a netlink request.

    The offending message is kept in ``nl_msg``; ``error`` holds the negation
    of ``nl_msg.error`` and is what __str__() hands to os.strerror().
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        reason = os.strerror(self.error)
        return f"Netlink error: {reason}\n{self.nl_msg}"
107
108
class ConfigError(Exception):
    """Raised by YnlFamily when it is constructed with an unusable setting."""
111
112
class NlAttr:
    """A single netlink attribute parsed out of a byte buffer.

    The 4-byte attribute header is two native-endian u16 values: the length
    (which counts the header itself) and the type. After construction:

      type        - attribute type with the NLA_F_* flag bits cleared
      is_nest     - non-zero when the NLA_F_NESTED bit was set
      payload_len - the length field exactly as found in the header
      full_len    - the length rounded up to a multiple of 4, i.e. the
                    distance to the next attribute in the buffer
      raw         - payload bytes (header stripped)
    """

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    # Pre-compiled struct formats for every fixed-width scalar type, in
    # native, big-endian and little-endian flavours.
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute starting at byte @offset of @raw.

        Raises an Exception when the length field is smaller than the
        attribute header: such an attribute has a full_len of 0, which would
        make every loop that walks a buffer by full_len spin forever.
        """
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        if self._len < 4:
            raise Exception(f"Invalid attribute length {self._len} at offset {offset}, "
                            f"must be at least 4")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for scalar @attr_type in the given byte order.

        @byte_order "big-endian" selects the big-endian format, any other
        truthy value the little-endian one, and a falsy value the native
        one. Raises KeyError for a type missing from type_formats.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" \
                else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as one scalar of type @attr_type."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable-width scalar ('sint' / 'uint').

        The width comes from the payload size (4 or 8 bytes), the signedness
        from the first letter of @attr_type.
        """
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar payload length must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII text, dropping its last character
        (the terminating NUL of a C string)."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes untouched."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of native-endian scalars."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
165
166
class NlAttrs:
    """Every attribute found in a buffer, parsed up front.

    Parsing starts at @offset and hops from attribute to attribute by
    NlAttr.full_len until the end of @msg. The result is kept in ``attrs``
    and is what iteration yields.
    """

    def __init__(self, msg, offset=0):
        parsed = []
        end = len(msg)
        while offset < end:
            parsed.append(NlAttr(msg, offset))
            offset += parsed[-1].full_len
        self.attrs = parsed

    def __iter__(self):
        return iter(self.attrs)

    def __repr__(self):
        # One attribute per line
        return '\n'.join(repr(attr) for attr in self.attrs)
186
187
class NlMsg:
    """One netlink message parsed from a receive buffer.

    The first 16 bytes at @offset are the header (length, type, flags,
    sequence number, port id); ``raw`` is everything after the header up to
    the message length. For NLMSG_ERROR and NLMSG_DONE messages the first
    four payload bytes are read as the signed error code and ``done`` is
    set. When the message also carries the NLM_F_ACK_TLVS flag the extended
    ack attributes are decoded into the ``extack`` dict.

    @attr_space, when given, is used to turn a top-level 'miss-type' number
    into the attribute's name (and doc string, if the spec has one).
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]
        (self.nl_len, self.nl_type, self.nl_flags,
         self.nl_seq, self.nl_portid) = struct.unpack("IHHII", self.hdr)
        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        # Where the extack attributes start within the payload.
        # NOTE(review): 20 is presumably the error code plus the echoed
        # request header; 4 is just the error code.
        if self.nl_type == Netlink.NLMSG_ERROR:
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            extack_off = 4
        else:
            extack_off = None

        if extack_off is not None:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1

        self.extack = None
        if not (self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off):
            return

        # Extack attributes which are plain u32 values, and their dict key
        u32_keys = {
            Netlink.NLMSGERR_ATTR_MISS_TYPE: 'miss-type',
            Netlink.NLMSGERR_ATTR_MISS_NEST: 'miss-nest',
            Netlink.NLMSGERR_ATTR_OFFS: 'bad-attr-offs',
        }

        self.extack = {}
        for tlv in NlAttrs(self.raw[extack_off:]):
            if tlv.type == Netlink.NLMSGERR_ATTR_MSG:
                self.extack['msg'] = tlv.as_strz()
            elif tlv.type in u32_keys:
                self.extack[u32_keys[tlv.type]] = tlv.as_scalar('u32')
            elif tlv.type == Netlink.NLMSGERR_ATTR_POLICY:
                self.extack['policy'] = self._decode_policy(tlv.raw)
            else:
                self.extack.setdefault('unknown', []).append(tlv)

        if not attr_space:
            return
        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' not in self.extack or 'miss-nest' in self.extack:
            return
        miss_type = self.extack['miss-type']
        if miss_type in attr_space.attrs_by_val:
            spec = attr_space.attrs_by_val[miss_type]
            self.extack['miss-type'] = spec['name']
            if 'doc' in spec:
                self.extack['miss-type-doc'] = spec['doc']

    def _decode_policy(self, raw):
        """Decode a nested policy description into a dict.

        Policy attributes which are not listed here are silently skipped.
        """
        # policy attribute type -> (result key, scalar type of the payload)
        scalar_fields = {
            Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S: ('min-value', 's64'),
            Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S: ('max-value', 's64'),
            Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U: ('min-value', 'u64'),
            Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U: ('max-value', 'u64'),
            Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH: ('min-length', 'u32'),
            Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH: ('max-length', 'u32'),
            Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK: ('bitfield32-mask', 'u32'),
            Netlink.NL_POLICY_TYPE_ATTR_MASK: ('mask', 'u64'),
        }

        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_id = attr.as_scalar('u32')
                policy['type'] = Netlink.AttrType(type_id).name
            elif attr.type in scalar_fields:
                key, scalar_type = scalar_fields[attr.type]
                policy[key] = attr.as_scalar(scalar_type)
        return policy

    def cmd(self):
        """Return the message type from the netlink header."""
        return self.nl_type

    def __repr__(self):
        parts = [f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"]
        if self.error:
            parts.append('\n\terror: ' + str(self.error))
        if self.extack:
            parts.append('\n\textack: ' + repr(self.extack))
        return ''.join(parts)
274
275
class NlMsgs:
    """All netlink messages contained in one receive buffer.

    @data is walked message by message; each one is parsed with NlMsg
    (passing @attr_space through) and collected in ``msgs``, which is also
    what iteration yields.

    Raises an Exception when a message announces a length shorter than the
    16-byte netlink header, since the walk could never get past it.
    """

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            if msg.nl_len < 16:
                raise Exception(f"Invalid netlink message length {msg.nl_len} at offset {offset}")
            self.msgs.append(msg)
            # The next header starts at the 4-byte aligned end of this message
            offset += (msg.nl_len + 3) & ~3

    def __iter__(self):
        yield from self.msgs
288
289
# Cache of the generic netlink families reported by the kernel, keyed by
# family name. Stays None until _genl_load_families() has been run.
genl_family_name_to_id = None
291
292
293def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
294    # we prepend length in _genl_msg_finalize()
295    if seq is None:
296        seq = random.randint(1, 1024)
297    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
298    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
299    return nlmsg + genlmsg
300
301
302def _genl_msg_finalize(msg):
303    return struct.pack("I", len(msg) + 4) + msg
304
305
def _genl_load_families():
    """Query the kernel for all generic netlink families and cache them.

    Sends a CTRL_CMD_GETFAMILY dump request on a temporary generic netlink
    socket and rebuilds genl_family_name_to_id from the replies. Each entry
    is a dict holding 'id' and 'name', plus 'maxattr' and 'mcast' (multicast
    group name -> id) when the kernel reported them. Families lacking a name
    or an id are dropped.

    A reply carrying an error is printed and ends the load, leaving
    whatever had been collected up to that point in the cache.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP
        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL, flags, Netlink.CTRL_CMD_GETFAMILY, 1))
        sock.send(request, 0)

        genl_family_name_to_id = {}

        # Keep reading until the dump is terminated by a done/error message
        while True:
            reply = sock.recv(128 * 1024)
            for nl_msg in NlMsgs(reply):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                family = {}
                for attr in NlAttrs(GenlMsg(nl_msg).raw):
                    kind = attr.type
                    if kind == Netlink.CTRL_ATTR_FAMILY_ID:
                        family['id'] = attr.as_scalar('u16')
                    elif kind == Netlink.CTRL_ATTR_FAMILY_NAME:
                        family['name'] = attr.as_strz()
                    elif kind == Netlink.CTRL_ATTR_MAXATTR:
                        family['maxattr'] = attr.as_scalar('u32')
                    elif kind == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        groups = family['mcast'] = {}
                        # One nested entry per multicast group
                        for group in NlAttrs(attr.raw):
                            group_name = None
                            group_id = None
                            for group_attr in NlAttrs(group.raw):
                                if group_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    group_name = group_attr.as_strz()
                                elif group_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    group_id = group_attr.as_scalar('u32')
                            if group_name and group_id is not None:
                                groups[group_name] = group_id

                if 'name' in family and 'id' in family:
                    genl_family_name_to_id[family['name']] = family
353
354
class GenlMsg:
    """A generic netlink message layered on top of a parsed netlink message.

    The first four payload bytes of @nl_msg are the genetlink header
    (command, version, two reserved bytes); ``raw`` is what follows it.

    ``raw_attrs`` starts out empty and is replaced with the parsed
    attributes by NetlinkProtocol.decode(). Initialising it here keeps
    repr() usable on messages that never went through decode().
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]
        self.raw_attrs = []

    def cmd(self):
        """Return the genetlink command number."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        for a in self.raw_attrs:
            msg += '\t\t' + repr(a) + '\n'
        return msg
370
371
class NetlinkProtocol:
    """Message framing for a raw (non-generic) netlink family.

    Holds the family name and the netlink protocol number, builds request
    headers and turns received NlMsg objects into decoded messages.
    """

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack the netlink header fields following the length word.

        A random sequence number between 1 and 1024 is used when @seq is
        None; the port id is always 0.
        """
        sequence = random.randint(1, 1024) if seq is None else seq
        return struct.pack("HHII", nl_type, nl_flags, sequence, 0)

    def message(self, flags, command, version, seq=None):
        """Build a request header; @command is the message type and
        @version is not used by raw netlink."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        # Raw netlink has no extra protocol header to peel off
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Attach the parsed attributes to the message as ``raw_attrs``.

        When @op is None the operation is looked up in ynl.rsp_by_value by
        the message's command. Attribute parsing starts after the
        operation's fixed header.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        attrs_start = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, attrs_start)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the value of group @mcast_name as defined in the spec."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        """Size in bytes of the netlink message header."""
        return 16
404
405
class GenlProtocol(NetlinkProtocol):
    """Message framing for a generic netlink family.

    The family is resolved by name in the genl_family_name_to_id cache,
    which is loaded from the kernel on first use. A family the kernel did
    not report surfaces as a KeyError from the constructor.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        if genl_family_name_to_id is None:
            _genl_load_families()

        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']

    def message(self, flags, command, version, seq=None):
        """Build a request header: the netlink part addressed to the
        family id, followed by the genetlink command/version header."""
        nl_header = self._message(self.family_id, flags, seq)
        return nl_header + struct.pack("BBH", command, version, 0)

    def _decode(self, nl_msg):
        # Peel off the genetlink header
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the id the kernel assigned to group @mcast_name.

        @mcast_groups is not consulted; the ids come from the family
        information cached at construction time.
        """
        known_groups = self.genl_family['mcast']
        if mcast_name in known_groups:
            return self.genl_family['mcast'][mcast_name]
        raise Exception(f'Multicast group "{mcast_name}" not present in the family')

    def msghdr_size(self):
        """Size of the netlink header plus the 4-byte genetlink header."""
        return 4 + super().msghdr_size()
432
433
class SpaceAttrs:
    """Chain of (attribute-set spec, attribute values) scopes.

    ``scopes`` lists the scope given to the constructor first, followed by
    the scopes of @outer, so lookups run from the innermost scope outwards.
    """

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        self.scopes = [self.SpecValuesPair(attr_space, attrs)]
        if outer:
            self.scopes += outer.scopes

    def lookup(self, name):
        """Return the value of attribute @name.

        The first scope whose spec defines @name decides the outcome: its
        value is returned, or an Exception is raised when that scope holds
        no value for it. An Exception is also raised when no scope defines
        the attribute at all.
        """
        for spec, values in self.scopes:
            if name not in spec:
                continue
            if name not in values:
                spec_name = spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
            return values[name]
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")
451
452
453#
454# YNL implementation details.
455#
456
457
458class YnlFamily(SpecFamily):
459    def __init__(self, def_path, schema=None, process_unknown=False,
460                 recv_size=0):
461        super().__init__(def_path, schema)
462
463        self.include_raw = False
464        self.process_unknown = process_unknown
465
466        try:
467            if self.proto == "netlink-raw":
468                self.nlproto = NetlinkProtocol(self.yaml['name'],
469                                               self.yaml['protonum'])
470            else:
471                self.nlproto = GenlProtocol(self.yaml['name'])
472        except KeyError:
473            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
474
475        self._recv_dbg = False
476        # Note that netlink will use conservative (min) message size for
477        # the first dump recv() on the socket, our setting will only matter
478        # from the second recv() on.
479        self._recv_size = recv_size if recv_size else 131072
480        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
481        # for a message, so smaller receive sizes will lead to truncation.
482        # Note that the min size for other families may be larger than 4k!
483        if self._recv_size < 4000:
484            raise ConfigError()
485
486        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
487        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
488        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
489        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
490
491        self.async_msg_ids = set()
492        self.async_msg_queue = []
493
494        for msg in self.msgs.values():
495            if msg.is_async:
496                self.async_msg_ids.add(msg.rsp_value)
497
498        for op_name, op in self.ops.items():
499            bound_f = functools.partial(self._op, op_name)
500            setattr(self, op.ident_name, bound_f)
501
502
503    def ntf_subscribe(self, mcast_name):
504        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
505        self.sock.bind((0, 0))
506        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
507                             mcast_id)
508
509    def set_recv_dbg(self, enabled):
510        self._recv_dbg = enabled
511
512    def _recv_dbg_print(self, reply, nl_msgs):
513        if not self._recv_dbg:
514            return
515        print("Recv: read", len(reply), "bytes,",
516              len(nl_msgs.msgs), "messages", file=sys.stderr)
517        for nl_msg in nl_msgs:
518            print("  ", nl_msg, file=sys.stderr)
519
520    def _encode_enum(self, attr_spec, value):
521        enum = self.consts[attr_spec['enum']]
522        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
523            scalar = 0
524            if isinstance(value, str):
525                value = [value]
526            for single_value in value:
527                scalar += enum.entries[single_value].user_value(as_flags = True)
528            return scalar
529        else:
530            return enum.entries[value].user_value()
531
532    def _get_scalar(self, attr_spec, value):
533        try:
534            return int(value)
535        except (ValueError, TypeError) as e:
536            if 'enum' not in attr_spec:
537                raise e
538        return self._encode_enum(attr_spec, value)
539
540    def _add_attr(self, space, name, value, search_attrs):
541        try:
542            attr = self.attr_sets[space][name]
543        except KeyError:
544            raise Exception(f"Space '{space}' has no attribute '{name}'")
545        nl_type = attr.value
546
547        if attr.is_multi and isinstance(value, list):
548            attr_payload = b''
549            for subvalue in value:
550                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
551            return attr_payload
552
553        if attr["type"] == 'nest':
554            nl_type |= Netlink.NLA_F_NESTED
555            attr_payload = b''
556            sub_attrs = SpaceAttrs(self.attr_sets[space], value, search_attrs)
557            for subname, subvalue in value.items():
558                attr_payload += self._add_attr(attr['nested-attributes'],
559                                               subname, subvalue, sub_attrs)
560        elif attr["type"] == 'flag':
561            if not value:
562                # If value is absent or false then skip attribute creation.
563                return b''
564            attr_payload = b''
565        elif attr["type"] == 'string':
566            attr_payload = str(value).encode('ascii') + b'\x00'
567        elif attr["type"] == 'binary':
568            if isinstance(value, bytes):
569                attr_payload = value
570            elif isinstance(value, str):
571                attr_payload = bytes.fromhex(value)
572            elif isinstance(value, dict) and attr.struct_name:
573                attr_payload = self._encode_struct(attr.struct_name, value)
574            else:
575                raise Exception(f'Unknown type for binary attribute, value: {value}')
576        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
577            scalar = self._get_scalar(attr, value)
578            if attr.is_auto_scalar:
579                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
580            else:
581                attr_type = attr["type"]
582            format = NlAttr.get_format(attr_type, attr.byte_order)
583            attr_payload = format.pack(scalar)
584        elif attr['type'] in "bitfield32":
585            scalar_value = self._get_scalar(attr, value["value"])
586            scalar_selector = self._get_scalar(attr, value["selector"])
587            attr_payload = struct.pack("II", scalar_value, scalar_selector)
588        elif attr['type'] == 'sub-message':
589            msg_format = self._resolve_selector(attr, search_attrs)
590            attr_payload = b''
591            if msg_format.fixed_header:
592                attr_payload += self._encode_struct(msg_format.fixed_header, value)
593            if msg_format.attr_set:
594                if msg_format.attr_set in self.attr_sets:
595                    nl_type |= Netlink.NLA_F_NESTED
596                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
597                    for subname, subvalue in value.items():
598                        attr_payload += self._add_attr(msg_format.attr_set,
599                                                       subname, subvalue, sub_attrs)
600                else:
601                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
602        else:
603            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
604
605        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
606        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
607
608    def _decode_enum(self, raw, attr_spec):
609        enum = self.consts[attr_spec['enum']]
610        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
611            i = 0
612            value = set()
613            while raw:
614                if raw & 1:
615                    value.add(enum.entries_by_val[i].name)
616                raw >>= 1
617                i += 1
618        else:
619            value = enum.entries_by_val[raw].name
620        return value
621
622    def _decode_binary(self, attr, attr_spec):
623        if attr_spec.struct_name:
624            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
625        elif attr_spec.sub_type:
626            decoded = attr.as_c_array(attr_spec.sub_type)
627        else:
628            decoded = attr.as_bin()
629            if attr_spec.display_hint:
630                decoded = self._formatted_string(decoded, attr_spec.display_hint)
631        return decoded
632
633    def _decode_array_attr(self, attr, attr_spec):
634        decoded = []
635        offset = 0
636        while offset < len(attr.raw):
637            item = NlAttr(attr.raw, offset)
638            offset += item.full_len
639
640            if attr_spec["sub-type"] == 'nest':
641                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
642                decoded.append({ item.type: subattrs })
643            elif attr_spec["sub-type"] == 'binary':
644                subattrs = item.as_bin()
645                if attr_spec.display_hint:
646                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
647                decoded.append(subattrs)
648            elif attr_spec["sub-type"] in NlAttr.type_formats:
649                subattrs = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
650                if attr_spec.display_hint:
651                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
652                decoded.append(subattrs)
653            else:
654                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
655        return decoded
656
657    def _decode_nest_type_value(self, attr, attr_spec):
658        decoded = {}
659        value = attr
660        for name in attr_spec['type-value']:
661            value = NlAttr(value.raw, 0)
662            decoded[name] = value.type
663        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
664        decoded.update(subattrs)
665        return decoded
666
667    def _decode_unknown(self, attr):
668        if attr.is_nest:
669            return self._decode(NlAttrs(attr.raw), None)
670        else:
671            return attr.as_bin()
672
673    def _rsp_add(self, rsp, name, is_multi, decoded):
674        if is_multi == None:
675            if name in rsp and type(rsp[name]) is not list:
676                rsp[name] = [rsp[name]]
677                is_multi = True
678            else:
679                is_multi = False
680
681        if not is_multi:
682            rsp[name] = decoded
683        elif name in rsp:
684            rsp[name].append(decoded)
685        else:
686            rsp[name] = [decoded]
687
688    def _resolve_selector(self, attr_spec, search_attrs):
689        sub_msg = attr_spec.sub_message
690        if sub_msg not in self.sub_msgs:
691            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
692        sub_msg_spec = self.sub_msgs[sub_msg]
693
694        selector = attr_spec.selector
695        value = search_attrs.lookup(selector)
696        if value not in sub_msg_spec.formats:
697            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
698
699        spec = sub_msg_spec.formats[value]
700        return spec
701
702    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
703        msg_format = self._resolve_selector(attr_spec, search_attrs)
704        decoded = {}
705        offset = 0
706        if msg_format.fixed_header:
707            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header));
708            offset = self._struct_size(msg_format.fixed_header)
709        if msg_format.attr_set:
710            if msg_format.attr_set in self.attr_sets:
711                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
712                decoded.update(subdict)
713            else:
714                raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")
715        return decoded
716
    def _decode(self, attrs, space, outer_attrs = None):
        """Decode the attributes in @attrs into a dict keyed by name.

        @space names the attribute-set used to interpret the attributes. It
        may be None (as done by _decode_unknown()), in which case every
        attribute is handled as unknown. @outer_attrs is the scope chain of
        the enclosing attributes; it is extended with the attributes decoded
        here so that sub-message selectors can be resolved.

        Attributes missing from the spec, or of a type not handled below,
        raise an Exception unless self.process_unknown is set; then they
        are decoded by _decode_unknown(), and those missing from the spec
        are stored under the name "UnknownAttr(<type>)".
        """
        rsp = dict()
        if space:
            attr_space = self.attr_sets[space]
            # The scope shares rsp, so values become visible to selector
            # lookups as soon as they are decoded.
            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)

        for attr in attrs:
            try:
                attr_spec = attr_space.attrs_by_val[attr.type]
            # UnboundLocalError covers the case of a falsy @space, where
            # attr_space was never assigned above.
            except (KeyError, UnboundLocalError):
                if not self.process_unknown:
                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
                attr_name = f"UnknownAttr({attr.type})"
                # is_multi is None: it is not known whether this repeats
                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
                continue

            if attr_spec["type"] == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
                decoded = subdict
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                # A flag has no payload, its presence is the value
                decoded = True
            elif attr_spec.is_auto_scalar:
                decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
            elif attr_spec["type"] in NlAttr.type_formats:
                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
                # An enum takes precedence over a display hint
                if 'enum' in attr_spec:
                    decoded = self._decode_enum(decoded, attr_spec)
                elif attr_spec.display_hint:
                    decoded = self._formatted_string(decoded, attr_spec.display_hint)
            elif attr_spec["type"] == 'indexed-array':
                decoded = self._decode_array_attr(attr, attr_spec)
            elif attr_spec["type"] == 'bitfield32':
                # Payload is two native-endian u32: value, then selector
                value, selector = struct.unpack("II", attr.raw)
                if 'enum' in attr_spec:
                    value = self._decode_enum(value, attr_spec)
                    selector = self._decode_enum(selector, attr_spec)
                decoded = {"value": value, "selector": selector}
            elif attr_spec["type"] == 'sub-message':
                decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
            elif attr_spec["type"] == 'nest-type-value':
                decoded = self._decode_nest_type_value(attr, attr_spec)
            else:
                if not self.process_unknown:
                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
                decoded = self._decode_unknown(attr)

            self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)

        return rsp
770
771    def _decode_extack_path(self, attrs, attr_set, offset, target):
772        for attr in attrs:
773            try:
774                attr_spec = attr_set.attrs_by_val[attr.type]
775            except KeyError:
776                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
777            if offset > target:
778                break
779            if offset == target:
780                return '.' + attr_spec.name
781
782            if offset + attr.full_len <= target:
783                offset += attr.full_len
784                continue
785            if attr_spec['type'] != 'nest':
786                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
787            offset += 4
788            subpath = self._decode_extack_path(NlAttrs(attr.raw),
789                                               self.attr_sets[attr_spec['nested-attributes']],
790                                               offset, target)
791            if subpath is None:
792                return None
793            return '.' + attr_spec.name + subpath
794
795        return None
796
797    def _decode_extack(self, request, op, extack):
798        if 'bad-attr-offs' not in extack:
799            return
800
801        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
802        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
803        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
804                                        extack['bad-attr-offs'])
805        if path:
806            del extack['bad-attr-offs']
807            extack['bad-attr'] = path
808
809    def _struct_size(self, name):
810        if name:
811            members = self.consts[name].members
812            size = 0
813            for m in members:
814                if m.type in ['pad', 'binary']:
815                    if m.struct:
816                        size += self._struct_size(m.struct)
817                    else:
818                        size += m.len
819                else:
820                    format = NlAttr.get_format(m.type, m.byte_order)
821                    size += format.size
822            return size
823        else:
824            return 0
825
826    def _decode_struct(self, data, name):
827        members = self.consts[name].members
828        attrs = dict()
829        offset = 0
830        for m in members:
831            value = None
832            if m.type == 'pad':
833                offset += m.len
834            elif m.type == 'binary':
835                if m.struct:
836                    len = self._struct_size(m.struct)
837                    value = self._decode_struct(data[offset : offset + len],
838                                                m.struct)
839                    offset += len
840                else:
841                    value = data[offset : offset + m.len]
842                    offset += m.len
843            else:
844                format = NlAttr.get_format(m.type, m.byte_order)
845                [ value ] = format.unpack_from(data, offset)
846                offset += format.size
847            if value is not None:
848                if m.enum:
849                    value = self._decode_enum(value, m)
850                elif m.display_hint:
851                    value = self._formatted_string(value, m.display_hint)
852                attrs[m.name] = value
853        return attrs
854
855    def _encode_struct(self, name, vals):
856        members = self.consts[name].members
857        attr_payload = b''
858        for m in members:
859            value = vals.pop(m.name) if m.name in vals else None
860            if m.type == 'pad':
861                attr_payload += bytearray(m.len)
862            elif m.type == 'binary':
863                if m.struct:
864                    if value is None:
865                        value = dict()
866                    attr_payload += self._encode_struct(m.struct, value)
867                else:
868                    if value is None:
869                        attr_payload += bytearray(m.len)
870                    else:
871                        attr_payload += bytes.fromhex(value)
872            else:
873                if value is None:
874                    value = 0
875                format = NlAttr.get_format(m.type, m.byte_order)
876                attr_payload += format.pack(value)
877        return attr_payload
878
879    def _formatted_string(self, raw, display_hint):
880        if display_hint == 'mac':
881            formatted = ':'.join('%02x' % b for b in raw)
882        elif display_hint == 'hex':
883            if isinstance(raw, int):
884                formatted = hex(raw)
885            else:
886                formatted = bytes.hex(raw, ' ')
887        elif display_hint in [ 'ipv4', 'ipv6' ]:
888            formatted = format(ipaddress.ip_address(raw))
889        elif display_hint == 'uuid':
890            formatted = str(uuid.UUID(bytes=raw))
891        else:
892            formatted = raw
893        return formatted
894
895    def handle_ntf(self, decoded):
896        msg = dict()
897        if self.include_raw:
898            msg['raw'] = decoded
899        op = self.rsp_by_value[decoded.cmd()]
900        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
901        if op.fixed_header:
902            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
903
904        msg['name'] = op['name']
905        msg['msg'] = attrs
906        self.async_msg_queue.append(msg)
907
908    def check_ntf(self):
909        while True:
910            try:
911                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
912            except BlockingIOError:
913                return
914
915            nms = NlMsgs(reply)
916            self._recv_dbg_print(reply, nms)
917            for nl_msg in nms:
918                if nl_msg.error:
919                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
920                    print(nl_msg)
921                    continue
922                if nl_msg.done:
923                    print("Netlink done while checking for ntf!?")
924                    continue
925
926                decoded = self.nlproto.decode(self, nl_msg, None)
927                if decoded.cmd() not in self.async_msg_ids:
928                    print("Unexpected msg id done while checking for ntf", decoded)
929                    continue
930
931                self.handle_ntf(decoded)
932
933    def operation_do_attributes(self, name):
934      """
935      For a given operation name, find and return a supported
936      set of attributes (as a dict).
937      """
938      op = self.find_operation(name)
939      if not op:
940        return None
941
942      return op['do']['request']['attributes'].copy()
943
944    def _encode_message(self, op, vals, flags, req_seq):
945        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
946        for flag in flags or []:
947            nl_flags |= flag
948
949        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
950        if op.fixed_header:
951            msg += self._encode_struct(op.fixed_header, vals)
952        search_attrs = SpaceAttrs(op.attr_set, vals)
953        for name, value in vals.items():
954            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
955        msg = _genl_msg_finalize(msg)
956        return msg
957
958    def _ops(self, ops):
959        reqs_by_seq = {}
960        req_seq = random.randint(1024, 65535)
961        payload = b''
962        for (method, vals, flags) in ops:
963            op = self.ops[method]
964            msg = self._encode_message(op, vals, flags, req_seq)
965            reqs_by_seq[req_seq] = (op, msg, flags)
966            payload += msg
967            req_seq += 1
968
969        self.sock.send(payload, 0)
970
971        done = False
972        rsp = []
973        op_rsp = []
974        while not done:
975            reply = self.sock.recv(self._recv_size)
976            nms = NlMsgs(reply, attr_space=op.attr_set)
977            self._recv_dbg_print(reply, nms)
978            for nl_msg in nms:
979                if nl_msg.nl_seq in reqs_by_seq:
980                    (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
981                    if nl_msg.extack:
982                        self._decode_extack(req_msg, op, nl_msg.extack)
983                else:
984                    op = None
985                    req_flags = []
986
987                if nl_msg.error:
988                    raise NlError(nl_msg)
989                if nl_msg.done:
990                    if nl_msg.extack:
991                        print("Netlink warning:")
992                        print(nl_msg)
993
994                    if Netlink.NLM_F_DUMP in req_flags:
995                        rsp.append(op_rsp)
996                    elif not op_rsp:
997                        rsp.append(None)
998                    elif len(op_rsp) == 1:
999                        rsp.append(op_rsp[0])
1000                    else:
1001                        rsp.append(op_rsp)
1002                    op_rsp = []
1003
1004                    del reqs_by_seq[nl_msg.nl_seq]
1005                    done = len(reqs_by_seq) == 0
1006                    break
1007
1008                decoded = self.nlproto.decode(self, nl_msg, op)
1009
1010                # Check if this is a reply to our request
1011                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
1012                    if decoded.cmd() in self.async_msg_ids:
1013                        self.handle_ntf(decoded)
1014                        continue
1015                    else:
1016                        print('Unexpected message: ' + repr(decoded))
1017                        continue
1018
1019                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
1020                if op.fixed_header:
1021                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
1022                op_rsp.append(rsp_msg)
1023
1024        return rsp
1025
1026    def _op(self, method, vals, flags=None, dump=False):
1027        req_flags = flags or []
1028        if dump:
1029            req_flags.append(Netlink.NLM_F_DUMP)
1030
1031        ops = [(method, vals, req_flags)]
1032        return self._ops(ops)[0]
1033
1034    def do(self, method, vals, flags=None):
1035        return self._op(method, vals, flags)
1036
1037    def dump(self, method, vals):
1038        return self._op(method, vals, dump=True)
1039
1040    def do_multi(self, ops):
1041        return self._ops(ops)
1042