xref: /linux/tools/net/ynl/lib/ynl.py (revision 4b132aacb0768ac1e652cf517097ea6f237214b9)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4from enum import Enum
5import functools
6import os
7import random
8import socket
9import struct
10from struct import Struct
11import sys
12import yaml
13import ipaddress
14import uuid
15
16from .nlspec import SpecFamily
17
18#
19# Generic Netlink code which should really be in some library, but I can't quickly find one.
20#
21
22
class Netlink:
    """Numeric constants for raw netlink, generic netlink and nlctrl.

    NOTE(review): these values are hand-copied from the kernel UAPI headers
    (linux/netlink.h, linux/genetlink.h); confirm against them when adding
    or changing entries.
    """
    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    # GET request modifiers; NLM_F_DUMP below combines them.
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    # NEW request modifiers (per linux/netlink.h); they intentionally reuse
    # the same bit values as the GET modifiers above.
    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    # Flags on ACK / error messages (NlMsg checks NLM_F_ACK_TLVS before
    # parsing extack attributes); again the bit values are shared.
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # Flag bits carried in nla_type next to the attribute number.
    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    # Despite the name, this is the set of flag bits that NlAttr clears
    # from nla_type to recover the attribute number.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Names for NL_POLICY_TYPE_ATTR_TYPE values, used by
    # NlMsg._decode_policy(); the functional Enum API numbers the members
    # from 1, so 'flag' == 1, 'u8' == 2, and so on.
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                  's8', 's16', 's32', 's64',
                                  'binary', 'string', 'nul-string',
                                  'nested', 'nested-array',
                                  'bitfield32', 'sint', 'uint'])
98                                  'bitfield32', 'sint', 'uint'])
99
class NlError(Exception):
    """Exception raised for a netlink error reply.

    Attributes:
        nl_msg: the message that carried the error.
        error: positive errno value; nl_msg.error holds it negated, so it
            is flipped here for os.strerror().
    """

    def __init__(self, nl_msg):
        # Pass the message to Exception so args/pickling carry it too.
        super().__init__(nl_msg)
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"
107
108
class ConfigError(Exception):
    """Raised when YnlFamily is given an unusable setting (e.g. a too small recv_size)."""
111
112
class NlAttr:
    """A single netlink attribute (TLV) read from a raw byte buffer.

    Attributes:
        type: attribute number with the Netlink.NLA_TYPE_MASK flag bits cleared.
        is_nest: non-zero when the NLA_F_NESTED bit was set.
        payload_len: the nla_len header field; note it includes the 4-byte
            header despite the name.
        full_len: nla_len rounded up to a multiple of 4, i.e. the distance
            to the next attribute.
        raw: the attribute payload with the 4-byte header stripped.
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the struct.Struct for scalar type `attr_type`.

        byte_order None selects native order, "big-endian" big-endian and
        any other value little-endian.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" \
                else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the payload as a fixed-width scalar of `attr_type`."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a variable-width ('sint'/'uint') scalar of 4 or 8 bytes.

        Raises:
            Exception: the payload is neither 4 nor 8 bytes long.
        """
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        # First letter of the type ('s' or 'u') plus the width in bits.
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode as ASCII, dropping the final (NUL terminator) character."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes unchanged."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as a native-order array of scalar `type` values."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
165
166
class NlAttrs:
    """Ordered, iterable list of NlAttr parsed from `msg` starting at `offset`."""

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            if attr.payload_len < 4:
                # nla_len includes the 4-byte header, so anything shorter is
                # malformed; a length of 0 would also never advance offset and
                # loop forever.
                raise Exception(f"Malformed netlink attribute at offset {offset}: "
                                f"length {attr.payload_len}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)
186
187
class NlMsg:
    """One netlink message parsed out of a byte buffer.

    Attributes:
        hdr: the 16 raw nlmsghdr bytes.
        nl_len, nl_type, nl_flags, nl_seq, nl_portid: the nlmsghdr fields.
        raw: payload following the header, up to nl_len.
        error: error code from an NLMSG_ERROR / NLMSG_DONE payload, else 0.
        done: 1 for NLMSG_ERROR / NLMSG_DONE messages, else 0.
        extack: dict of decoded extended-ACK attributes, or None.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # Error code followed by the original request's nlmsghdr.
            extack_off = 20
            if (self.nl_flags & Netlink.NLM_F_ACK_TLVS and
                    not self.nl_flags & Netlink.NLM_F_CAPPED):
                # Not capped: the whole original request is echoed after the
                # error code, so the TLVs start after its (4-byte aligned)
                # length, taken from the echoed header.
                orig_len = struct.unpack_from("I", self.raw, 4)[0]
                extack_off = 4 + ((orig_len + 3) & ~3)
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        self.extack['miss-type'] = spec['name']
                        if 'doc' in spec:
                            self.extack['miss-type-doc'] = spec['doc']

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY nest into a dict.

        Policy attributes not listed here are ignored. An attribute type
        number with no name in Netlink.AttrType is kept as the raw number.
        """
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_num = attr.as_scalar('u32')
                try:
                    policy['type'] = Netlink.AttrType(type_num).name
                except ValueError:
                    # A type newer than the AttrType table must not make the
                    # whole error message unparsable.
                    policy['type'] = type_num
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def cmd(self):
        """Return the message's netlink type."""
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg
274
275
class NlMsgs:
    """All netlink messages found in one receive buffer, in order."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            if msg.nl_len < 16:
                # nlmsg_len always covers the 16-byte header; a smaller value
                # is malformed, and 0 would never advance offset.
                raise Exception(f"Malformed netlink message at offset {offset}: "
                                f"length {msg.nl_len}")
            # Messages are laid out on 4-byte boundaries (NLMSG_ALIGN).
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
288
289
# Cache of the kernel's generic netlink families, keyed by family name.
# Despite the name, each value is a dict with 'id' and 'name', and
# optionally 'maxattr' and 'mcast' (group name -> group id).
# None until _genl_load_families() fills it on first use.
genl_family_name_to_id = None
291
292
293def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
294    # we prepend length in _genl_msg_finalize()
295    if seq is None:
296        seq = random.randint(1, 1024)
297    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
298    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
299    return nlmsg + genlmsg
300
301
302def _genl_msg_finalize(msg):
303    return struct.pack("I", len(msg) + 4) + msg
304
305
def _genl_parse_mcast_groups(attr):
    """Map multicast group name -> id from a CTRL_ATTR_MCAST_GROUPS nest."""
    groups = dict()
    for entry in NlAttrs(attr.raw):
        mcast_name = None
        mcast_id = None
        for entry_attr in NlAttrs(entry.raw):
            if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                mcast_name = entry_attr.as_strz()
            elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                mcast_id = entry_attr.as_scalar('u32')
        # Entries lacking a name or an id are skipped.
        if mcast_name and mcast_id is not None:
            groups[mcast_name] = mcast_id
    return groups


def _genl_parse_family(gm):
    """Collect the nlctrl attributes of one family-dump reply into a dict.

    The dict may hold 'id', 'name', 'maxattr' and 'mcast'; each key is
    present only when the reply carried the matching attribute.
    """
    fam = dict()
    for attr in NlAttrs(gm.raw):
        if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
            fam['id'] = attr.as_scalar('u16')
        elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
            fam['name'] = attr.as_strz()
        elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
            fam['maxattr'] = attr.as_scalar('u32')
        elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
            fam['mcast'] = _genl_parse_mcast_groups(attr)
    return fam


def _genl_load_families():
    """Dump all generic netlink families into genl_family_name_to_id.

    Sends a CTRL_CMD_GETFAMILY dump on a short-lived NETLINK_GENERIC
    socket. Each reply that carries both a name and an id is stored,
    keyed by name. Returns when the dump completes. On an error reply it
    reports the error on stderr and returns, keeping whatever was already
    collected.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            for nl_msg in NlMsgs(reply):
                if nl_msg.error:
                    # stderr, so the diagnostic does not get mixed into the
                    # tool output that callers print on stdout.
                    print("Netlink error:", nl_msg.error, file=sys.stderr)
                    return
                if nl_msg.done:
                    return

                fam = _genl_parse_family(GenlMsg(nl_msg))
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam
353
354
class GenlMsg:
    """A generic netlink message: an NlMsg plus its 4-byte genlmsghdr.

    Attributes:
        nl: the underlying NlMsg.
        genl_cmd, genl_version: the genlmsghdr fields.
        raw: payload following the genlmsghdr.
        raw_attrs: attached later by NetlinkProtocol.decode(); absent on
            messages that were never passed through it.
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the generic netlink command."""
        return self.genl_cmd

    def __repr__(self):
        # The netlink header repr has no trailing newline; add one so the
        # genl line starts on its own line.
        msg = repr(self.nl) + '\n'
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs only exists after NetlinkProtocol.decode(); without it
        # there is nothing to list, rather than an AttributeError.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg
370
371
class NetlinkProtocol:
    """Request framing and reply decoding for a classic netlink family.

    Attributes:
        family_name: the family's name from its spec.
        proto_num: netlink protocol number used to open the socket.
    """

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack an nlmsghdr minus its leading length field (portid 0).

        A random seq in [1, 1024] is used when seq is None.
        """
        sequence = random.randint(1, 1024) if seq is None else seq
        return struct.pack("HHII", nl_type, nl_flags, sequence, 0)

    def message(self, flags, command, version, seq=None):
        """Build a request header; the command doubles as the netlink type.

        `version` is not used for classic netlink.
        """
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        """Wrap a reply; classic netlink uses the NlMsg unchanged."""
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap nl_msg and attach its attributes, skipping op's fixed header."""
        wrapped = self._decode(nl_msg)
        header_size = ynl._struct_size(op.fixed_header)
        wrapped.raw_attrs = NlAttrs(wrapped.raw, header_size)
        return wrapped

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the id of `mcast_name` from the spec's multicast groups."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        """Bytes of message header preceding the payload (the nlmsghdr)."""
        return 16
402
403
class GenlProtocol(NetlinkProtocol):
    """Request framing and reply decoding for a generic netlink family.

    The family id and multicast group ids come from the kernel's family
    list, which is loaded on first use.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        if genl_family_name_to_id is None:
            _genl_load_families()

        # A KeyError here means the kernel did not report this family;
        # YnlFamily reports it as "not supported by the kernel".
        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = self.genl_family['id']

    def message(self, flags, command, version, seq=None):
        """Build an nlmsghdr typed with the family id, then a genlmsghdr."""
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        """Wrap a reply in GenlMsg to split off the genlmsghdr."""
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the kernel-assigned id of multicast group `mcast_name`.

        Raises:
            Exception: the family has no group with that name.
        """
        # Families without multicast groups have no 'mcast' entry at all,
        # which must not surface as a bare KeyError.
        family_groups = self.genl_family.get('mcast', {})
        if mcast_name not in family_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return family_groups[mcast_name]

    def msghdr_size(self):
        """nlmsghdr plus the 4-byte genlmsghdr."""
        return super().msghdr_size() + 4
430
431
class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes used to resolve selectors.

    The innermost scope comes first; lookup() walks outwards through the
    scopes inherited from `outer`.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        inherited = outer.scopes if outer else []
        self.scopes = [self.SpecValuesPair(attr_space, attrs), *inherited]

    def lookup(self, name):
        """Return the value of attribute `name` from the nearest scope defining it.

        Raises:
            Exception: the nearest scope whose spec defines `name` has no
                value for it, or no scope's spec defines it.
        """
        for spec, values in self.scopes:
            if name not in spec:
                continue
            if name in values:
                return values[name]
            space_name = spec.yaml['name']
            raise Exception(
                f"No value for '{name}' in attribute space '{space_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")
449
450
451#
452# YNL implementation details.
453#
454
455
456class YnlFamily(SpecFamily):
457    def __init__(self, def_path, schema=None, process_unknown=False,
458                 recv_size=0):
459        super().__init__(def_path, schema)
460
461        self.include_raw = False
462        self.process_unknown = process_unknown
463
464        try:
465            if self.proto == "netlink-raw":
466                self.nlproto = NetlinkProtocol(self.yaml['name'],
467                                               self.yaml['protonum'])
468            else:
469                self.nlproto = GenlProtocol(self.yaml['name'])
470        except KeyError:
471            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
472
473        self._recv_dbg = False
474        # Note that netlink will use conservative (min) message size for
475        # the first dump recv() on the socket, our setting will only matter
476        # from the second recv() on.
477        self._recv_size = recv_size if recv_size else 131072
478        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
479        # for a message, so smaller receive sizes will lead to truncation.
480        # Note that the min size for other families may be larger than 4k!
481        if self._recv_size < 4000:
482            raise ConfigError()
483
484        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
485        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
486        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
487        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
488
489        self.async_msg_ids = set()
490        self.async_msg_queue = []
491
492        for msg in self.msgs.values():
493            if msg.is_async:
494                self.async_msg_ids.add(msg.rsp_value)
495
496        for op_name, op in self.ops.items():
497            bound_f = functools.partial(self._op, op_name)
498            setattr(self, op.ident_name, bound_f)
499
500
501    def ntf_subscribe(self, mcast_name):
502        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
503        self.sock.bind((0, 0))
504        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
505                             mcast_id)
506
507    def set_recv_dbg(self, enabled):
508        self._recv_dbg = enabled
509
510    def _recv_dbg_print(self, reply, nl_msgs):
511        if not self._recv_dbg:
512            return
513        print("Recv: read", len(reply), "bytes,",
514              len(nl_msgs.msgs), "messages", file=sys.stderr)
515        for nl_msg in nl_msgs:
516            print("  ", nl_msg, file=sys.stderr)
517
518    def _encode_enum(self, attr_spec, value):
519        enum = self.consts[attr_spec['enum']]
520        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
521            scalar = 0
522            if isinstance(value, str):
523                value = [value]
524            for single_value in value:
525                scalar += enum.entries[single_value].user_value(as_flags = True)
526            return scalar
527        else:
528            return enum.entries[value].user_value()
529
530    def _get_scalar(self, attr_spec, value):
531        try:
532            return int(value)
533        except (ValueError, TypeError) as e:
534            if 'enum' not in attr_spec:
535                raise e
536        return self._encode_enum(attr_spec, value)
537
538    def _add_attr(self, space, name, value, search_attrs):
539        try:
540            attr = self.attr_sets[space][name]
541        except KeyError:
542            raise Exception(f"Space '{space}' has no attribute '{name}'")
543        nl_type = attr.value
544
545        if attr.is_multi and isinstance(value, list):
546            attr_payload = b''
547            for subvalue in value:
548                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
549            return attr_payload
550
551        if attr["type"] == 'nest':
552            nl_type |= Netlink.NLA_F_NESTED
553            attr_payload = b''
554            sub_attrs = SpaceAttrs(self.attr_sets[space], value, search_attrs)
555            for subname, subvalue in value.items():
556                attr_payload += self._add_attr(attr['nested-attributes'],
557                                               subname, subvalue, sub_attrs)
558        elif attr["type"] == 'flag':
559            if not value:
560                # If value is absent or false then skip attribute creation.
561                return b''
562            attr_payload = b''
563        elif attr["type"] == 'string':
564            attr_payload = str(value).encode('ascii') + b'\x00'
565        elif attr["type"] == 'binary':
566            if isinstance(value, bytes):
567                attr_payload = value
568            elif isinstance(value, str):
569                attr_payload = bytes.fromhex(value)
570            elif isinstance(value, dict) and attr.struct_name:
571                attr_payload = self._encode_struct(attr.struct_name, value)
572            else:
573                raise Exception(f'Unknown type for binary attribute, value: {value}')
574        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
575            scalar = self._get_scalar(attr, value)
576            if attr.is_auto_scalar:
577                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
578            else:
579                attr_type = attr["type"]
580            format = NlAttr.get_format(attr_type, attr.byte_order)
581            attr_payload = format.pack(scalar)
582        elif attr['type'] in "bitfield32":
583            scalar_value = self._get_scalar(attr, value["value"])
584            scalar_selector = self._get_scalar(attr, value["selector"])
585            attr_payload = struct.pack("II", scalar_value, scalar_selector)
586        elif attr['type'] == 'sub-message':
587            msg_format = self._resolve_selector(attr, search_attrs)
588            attr_payload = b''
589            if msg_format.fixed_header:
590                attr_payload += self._encode_struct(msg_format.fixed_header, value)
591            if msg_format.attr_set:
592                if msg_format.attr_set in self.attr_sets:
593                    nl_type |= Netlink.NLA_F_NESTED
594                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
595                    for subname, subvalue in value.items():
596                        attr_payload += self._add_attr(msg_format.attr_set,
597                                                       subname, subvalue, sub_attrs)
598                else:
599                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
600        else:
601            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
602
603        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
604        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
605
606    def _decode_enum(self, raw, attr_spec):
607        enum = self.consts[attr_spec['enum']]
608        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
609            i = 0
610            value = set()
611            while raw:
612                if raw & 1:
613                    value.add(enum.entries_by_val[i].name)
614                raw >>= 1
615                i += 1
616        else:
617            value = enum.entries_by_val[raw].name
618        return value
619
620    def _decode_binary(self, attr, attr_spec):
621        if attr_spec.struct_name:
622            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
623        elif attr_spec.sub_type:
624            decoded = attr.as_c_array(attr_spec.sub_type)
625        else:
626            decoded = attr.as_bin()
627            if attr_spec.display_hint:
628                decoded = self._formatted_string(decoded, attr_spec.display_hint)
629        return decoded
630
631    def _decode_array_attr(self, attr, attr_spec):
632        decoded = []
633        offset = 0
634        while offset < len(attr.raw):
635            item = NlAttr(attr.raw, offset)
636            offset += item.full_len
637
638            if attr_spec["sub-type"] == 'nest':
639                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
640                decoded.append({ item.type: subattrs })
641            elif attr_spec["sub-type"] == 'binary':
642                subattrs = item.as_bin()
643                if attr_spec.display_hint:
644                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
645                decoded.append(subattrs)
646            elif attr_spec["sub-type"] in NlAttr.type_formats:
647                subattrs = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
648                if attr_spec.display_hint:
649                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
650                decoded.append(subattrs)
651            else:
652                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
653        return decoded
654
655    def _decode_nest_type_value(self, attr, attr_spec):
656        decoded = {}
657        value = attr
658        for name in attr_spec['type-value']:
659            value = NlAttr(value.raw, 0)
660            decoded[name] = value.type
661        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
662        decoded.update(subattrs)
663        return decoded
664
665    def _decode_unknown(self, attr):
666        if attr.is_nest:
667            return self._decode(NlAttrs(attr.raw), None)
668        else:
669            return attr.as_bin()
670
671    def _rsp_add(self, rsp, name, is_multi, decoded):
672        if is_multi == None:
673            if name in rsp and type(rsp[name]) is not list:
674                rsp[name] = [rsp[name]]
675                is_multi = True
676            else:
677                is_multi = False
678
679        if not is_multi:
680            rsp[name] = decoded
681        elif name in rsp:
682            rsp[name].append(decoded)
683        else:
684            rsp[name] = [decoded]
685
686    def _resolve_selector(self, attr_spec, search_attrs):
687        sub_msg = attr_spec.sub_message
688        if sub_msg not in self.sub_msgs:
689            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
690        sub_msg_spec = self.sub_msgs[sub_msg]
691
692        selector = attr_spec.selector
693        value = search_attrs.lookup(selector)
694        if value not in sub_msg_spec.formats:
695            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
696
697        spec = sub_msg_spec.formats[value]
698        return spec
699
700    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
701        msg_format = self._resolve_selector(attr_spec, search_attrs)
702        decoded = {}
703        offset = 0
704        if msg_format.fixed_header:
705            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header));
706            offset = self._struct_size(msg_format.fixed_header)
707        if msg_format.attr_set:
708            if msg_format.attr_set in self.attr_sets:
709                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
710                decoded.update(subdict)
711            else:
712                raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")
713        return decoded
714
715    def _decode(self, attrs, space, outer_attrs = None):
716        rsp = dict()
717        if space:
718            attr_space = self.attr_sets[space]
719            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
720
721        for attr in attrs:
722            try:
723                attr_spec = attr_space.attrs_by_val[attr.type]
724            except (KeyError, UnboundLocalError):
725                if not self.process_unknown:
726                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
727                attr_name = f"UnknownAttr({attr.type})"
728                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
729                continue
730
731            if attr_spec["type"] == 'nest':
732                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
733                decoded = subdict
734            elif attr_spec["type"] == 'string':
735                decoded = attr.as_strz()
736            elif attr_spec["type"] == 'binary':
737                decoded = self._decode_binary(attr, attr_spec)
738            elif attr_spec["type"] == 'flag':
739                decoded = True
740            elif attr_spec.is_auto_scalar:
741                decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
742            elif attr_spec["type"] in NlAttr.type_formats:
743                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
744                if 'enum' in attr_spec:
745                    decoded = self._decode_enum(decoded, attr_spec)
746                elif attr_spec.display_hint:
747                    decoded = self._formatted_string(decoded, attr_spec.display_hint)
748            elif attr_spec["type"] == 'indexed-array':
749                decoded = self._decode_array_attr(attr, attr_spec)
750            elif attr_spec["type"] == 'bitfield32':
751                value, selector = struct.unpack("II", attr.raw)
752                if 'enum' in attr_spec:
753                    value = self._decode_enum(value, attr_spec)
754                    selector = self._decode_enum(selector, attr_spec)
755                decoded = {"value": value, "selector": selector}
756            elif attr_spec["type"] == 'sub-message':
757                decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
758            elif attr_spec["type"] == 'nest-type-value':
759                decoded = self._decode_nest_type_value(attr, attr_spec)
760            else:
761                if not self.process_unknown:
762                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
763                decoded = self._decode_unknown(attr)
764
765            self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
766
767        return rsp
768
769    def _decode_extack_path(self, attrs, attr_set, offset, target):
770        for attr in attrs:
771            try:
772                attr_spec = attr_set.attrs_by_val[attr.type]
773            except KeyError:
774                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
775            if offset > target:
776                break
777            if offset == target:
778                return '.' + attr_spec.name
779
780            if offset + attr.full_len <= target:
781                offset += attr.full_len
782                continue
783            if attr_spec['type'] != 'nest':
784                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
785            offset += 4
786            subpath = self._decode_extack_path(NlAttrs(attr.raw),
787                                               self.attr_sets[attr_spec['nested-attributes']],
788                                               offset, target)
789            if subpath is None:
790                return None
791            return '.' + attr_spec.name + subpath
792
793        return None
794
795    def _decode_extack(self, request, op, extack):
796        if 'bad-attr-offs' not in extack:
797            return
798
799        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
800        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
801        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
802                                        extack['bad-attr-offs'])
803        if path:
804            del extack['bad-attr-offs']
805            extack['bad-attr'] = path
806
807    def _struct_size(self, name):
808        if name:
809            members = self.consts[name].members
810            size = 0
811            for m in members:
812                if m.type in ['pad', 'binary']:
813                    if m.struct:
814                        size += self._struct_size(m.struct)
815                    else:
816                        size += m.len
817                else:
818                    format = NlAttr.get_format(m.type, m.byte_order)
819                    size += format.size
820            return size
821        else:
822            return 0
823
824    def _decode_struct(self, data, name):
825        members = self.consts[name].members
826        attrs = dict()
827        offset = 0
828        for m in members:
829            value = None
830            if m.type == 'pad':
831                offset += m.len
832            elif m.type == 'binary':
833                if m.struct:
834                    len = self._struct_size(m.struct)
835                    value = self._decode_struct(data[offset : offset + len],
836                                                m.struct)
837                    offset += len
838                else:
839                    value = data[offset : offset + m.len]
840                    offset += m.len
841            else:
842                format = NlAttr.get_format(m.type, m.byte_order)
843                [ value ] = format.unpack_from(data, offset)
844                offset += format.size
845            if value is not None:
846                if m.enum:
847                    value = self._decode_enum(value, m)
848                elif m.display_hint:
849                    value = self._formatted_string(value, m.display_hint)
850                attrs[m.name] = value
851        return attrs
852
853    def _encode_struct(self, name, vals):
854        members = self.consts[name].members
855        attr_payload = b''
856        for m in members:
857            value = vals.pop(m.name) if m.name in vals else None
858            if m.type == 'pad':
859                attr_payload += bytearray(m.len)
860            elif m.type == 'binary':
861                if m.struct:
862                    if value is None:
863                        value = dict()
864                    attr_payload += self._encode_struct(m.struct, value)
865                else:
866                    if value is None:
867                        attr_payload += bytearray(m.len)
868                    else:
869                        attr_payload += bytes.fromhex(value)
870            else:
871                if value is None:
872                    value = 0
873                format = NlAttr.get_format(m.type, m.byte_order)
874                attr_payload += format.pack(value)
875        return attr_payload
876
877    def _formatted_string(self, raw, display_hint):
878        if display_hint == 'mac':
879            formatted = ':'.join('%02x' % b for b in raw)
880        elif display_hint == 'hex':
881            if isinstance(raw, int):
882                formatted = hex(raw)
883            else:
884                formatted = bytes.hex(raw, ' ')
885        elif display_hint in [ 'ipv4', 'ipv6' ]:
886            formatted = format(ipaddress.ip_address(raw))
887        elif display_hint == 'uuid':
888            formatted = str(uuid.UUID(bytes=raw))
889        else:
890            formatted = raw
891        return formatted
892
893    def handle_ntf(self, decoded):
894        msg = dict()
895        if self.include_raw:
896            msg['raw'] = decoded
897        op = self.rsp_by_value[decoded.cmd()]
898        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
899        if op.fixed_header:
900            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
901
902        msg['name'] = op['name']
903        msg['msg'] = attrs
904        self.async_msg_queue.append(msg)
905
906    def check_ntf(self):
907        while True:
908            try:
909                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
910            except BlockingIOError:
911                return
912
913            nms = NlMsgs(reply)
914            self._recv_dbg_print(reply, nms)
915            for nl_msg in nms:
916                if nl_msg.error:
917                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
918                    print(nl_msg)
919                    continue
920                if nl_msg.done:
921                    print("Netlink done while checking for ntf!?")
922                    continue
923
924                op = self.rsp_by_value[nl_msg.cmd()]
925                decoded = self.nlproto.decode(self, nl_msg, op)
926                if decoded.cmd() not in self.async_msg_ids:
927                    print("Unexpected msg id done while checking for ntf", decoded)
928                    continue
929
930                self.handle_ntf(decoded)
931
932    def operation_do_attributes(self, name):
933      """
934      For a given operation name, find and return a supported
935      set of attributes (as a dict).
936      """
937      op = self.find_operation(name)
938      if not op:
939        return None
940
941      return op['do']['request']['attributes'].copy()
942
943    def _encode_message(self, op, vals, flags, req_seq):
944        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
945        for flag in flags or []:
946            nl_flags |= flag
947
948        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
949        if op.fixed_header:
950            msg += self._encode_struct(op.fixed_header, vals)
951        search_attrs = SpaceAttrs(op.attr_set, vals)
952        for name, value in vals.items():
953            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
954        msg = _genl_msg_finalize(msg)
955        return msg
956
    def _ops(self, ops):
        """
        Send a batch of requests in one go and collect their responses.

        @ops is an iterable of (method, vals, flags) tuples. Each request is
        encoded with its own sequence number, counting up from a random
        starting value, and the whole batch is written with a single send().
        Replies are then received until a 'done' message has been seen for
        every request sequence number.

        Returns a list with one entry per completed request, appended in the
        order the 'done' messages arrive. The entry is a list of decoded
        replies when the request's flags contain NLM_F_DUMP. Otherwise it is
        None (no reply), the decoded reply itself (exactly one), or a list
        (several).

        Raises NlError as soon as any received message reports an error.
        Notifications received meanwhile are handed to handle_ntf(); other
        unexpected messages are printed and skipped.
        """
        reqs_by_seq = {}
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            # Keep the encoded request so that extack attribute offsets
            # reported against it can be resolved to attribute names.
            reqs_by_seq[req_seq] = (op, msg, flags)
            payload += msg
            req_seq += 1

        self.sock.send(payload, 0)

        done = False
        rsp = []
        op_rsp = []  # decoded replies for the request currently in progress
        while not done:
            reply = self.sock.recv(self._recv_size)
            # NOTE(review): 'op' here is whichever op was assigned last (the
            # last encoded request, or the op of the last message handled
            # below), not necessarily the op this reply belongs to.
            nms = NlMsgs(reply, attr_space=op.attr_set)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    if nl_msg.extack:
                        self._decode_extack(req_msg, op, nl_msg.extack)
                else:
                    # Not a reply to one of our requests (e.g. a
                    # notification): identify the op by its command.
                    op = self.rsp_by_value[nl_msg.cmd()]
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    # Dumps always produce a list; other requests collapse
                    # to None / the single reply / a list of replies.
                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    # NOTE(review): raises KeyError if a 'done' message
                    # carries a sequence number we did not send (the else
                    # branch above lets such messages reach this point).
                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    # NOTE(review): any messages left in this receive buffer
                    # after the 'done' are not processed; this appears to
                    # assume 'done' is always the last message in a datagram.
                    break

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp
1024
1025    def _op(self, method, vals, flags=None, dump=False):
1026        req_flags = flags or []
1027        if dump:
1028            req_flags.append(Netlink.NLM_F_DUMP)
1029
1030        ops = [(method, vals, req_flags)]
1031        return self._ops(ops)[0]
1032
1033    def do(self, method, vals, flags=None):
1034        return self._op(method, vals, flags)
1035
1036    def dump(self, method, vals):
1037        return self._op(method, vals, dump=True)
1038
1039    def do_multi(self, ops):
1040        return self._ops(ops)
1041