xref: /linux/tools/net/ynl/pyynl/lib/ynl.py (revision d8f87aa5fa0a4276491fa8ef436cd22605a3f9ba)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2#
3# pylint: disable=missing-class-docstring, missing-function-docstring
4# pylint: disable=too-many-branches, too-many-locals, too-many-instance-attributes
5# pylint: disable=too-many-lines
6
7"""
8YAML Netlink Library
9
10An implementation of the genetlink and raw netlink protocols.
11"""
12
13from collections import namedtuple
14from enum import Enum
15import functools
16import os
17import random
18import socket
19import struct
20from struct import Struct
21import sys
22import ipaddress
23import uuid
24import queue
25import selectors
26import time
27
28from .nlspec import SpecFamily
29
30#
31# Generic Netlink code which should really be in some library, but I can't quickly find one.
32#
33
34
class YnlException(Exception):
    """Generic error raised by the YNL library (spec lookups, bad data, etc.)."""
37
38
# pylint: disable=too-few-public-methods
class Netlink:
    """Namespace of numeric netlink, genetlink and nlctrl constants."""

    # Netlink socket options (option level SOL_NETLINK)
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink control message types
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # Standard nlmsg_flags
    NLM_F_REQUEST = 0x1
    NLM_F_ACK = 0x4

    # Modifiers for GET requests
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # Modifiers for NEW requests (they reuse the GET modifier bit values)
    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    # Flags on ACK / error replies (again overlapping bit values)
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    # Flag bits carried in nla_type; masked off to get the bare attr type
    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Generic netlink protocol number and the id of the nlctrl family
    NETLINK_GENERIC = 16
    GENL_ID_CTRL = 0x10

    # nlctrl command and attributes
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extended ACK attribute types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Attributes of a policy description nested in extack
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Attribute types reported by policy descriptions; numbered from 1,
    # so 'flag' == 1 and 'uint' == 17.
    AttrType = Enum('AttrType', ('flag', 'u8', 'u16', 'u32', 'u64',
                                 's8', 's16', 's32', 's64',
                                 'binary', 'string', 'nul-string',
                                 'nested', 'nested-array',
                                 'bitfield32', 'sint', 'uint'), start=1)
116
class NlError(Exception):
    """Exception wrapping a netlink reply message that carries an error.

    ``nl_msg`` is the reply itself; ``error`` is its error code negated, i.e.
    a positive errno suitable for os.strerror().
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        # Work on a copy so the message's own extack dict is left intact.
        details = dict(self.nl_msg.extack) if self.nl_msg.extack else {}
        pieces = ["Netlink error: "]
        if 'msg' in details:
            pieces.append(details.pop('msg') + ': ')
        pieces.append(os.strerror(self.error))
        if details:
            pieces.append(' ' + str(details))
        return ''.join(pieces)
133
134
class ConfigError(Exception):
    """Raised when a YnlFamily is configured with invalid settings."""
137
138
class NlAttr:
    """A single netlink attribute (nlattr TLV) parsed out of a buffer.

    ``type`` is the attribute type with the NESTED/NET_BYTEORDER flag bits
    masked off, ``raw`` is the payload without the 4-byte header or trailing
    padding, and ``full_len`` is the 4-byte-aligned size the attribute
    occupies in the buffer.
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # nla_len includes the 4-byte header, so a smaller value is malformed.
        # A zero length would also make callers that advance by full_len
        # (e.g. NlAttrs) spin forever on the same offset.
        if self._len < 4:
            raise YnlException(f"Malformed netlink attribute at offset {offset}: "
                               f"length {self._len} is shorter than its header")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for ``attr_type``.

        Native byte order is used when ``byte_order`` is falsy; otherwise
        "big-endian" selects big endian and any other value little endian.
        """
        format_ = cls.type_formats[attr_type]
        if byte_order:
            return format_.big if byte_order == "big-endian" \
                else format_.little
        return format_.native

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as a fixed-width integer of ``attr_type``."""
        format_ = self.get_format(attr_type, byte_order)
        return format_.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable-width ('sint'/'uint') payload of 4 or 8 bytes."""
        if len(self.raw) not in (4, 8):
            raise YnlException(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        format_ = self.get_format(real_type, byte_order)
        return format_.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII, dropping a trailing NUL if present."""
        text = self.raw.decode('ascii')
        # Only strip an actual terminator; a payload sent without one would
        # otherwise lose its last real character.
        return text[:-1] if text.endswith('\x00') else text

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, c_type):
        """Decode the payload as a packed array of native ``c_type`` values."""
        format_ = self.get_format(c_type)
        return [ x[0] for x in format_.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
191
192
class NlAttrs:
    """Ordered collection of NlAttr objects laid out back to back in ``msg``.

    Parsing starts at ``offset`` and continues to the end of the buffer.
    """

    def __init__(self, msg, offset=0):
        self.attrs = []
        end = len(msg)
        pos = offset
        while pos < end:
            item = NlAttr(msg, pos)
            self.attrs.append(item)
            pos += item.full_len

    def __iter__(self):
        return iter(self.attrs)

    def __repr__(self):
        return '\n'.join(repr(item) for item in self.attrs)
212
213
class NlMsg:
    """One netlink message parsed out of a receive buffer.

    Attributes set here: the nlmsghdr fields (nl_len, nl_type, nl_flags,
    nl_seq, nl_portid), the 16 header bytes in ``hdr``, the payload after
    the header in ``raw``, ``error``/``done`` for NLMSG_ERROR and NLMSG_DONE
    messages, and ``extack`` (a dict, or None when no ack TLVs are present).
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = self._error_extack_offset()
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # DONE carries only the 4-byte status ahead of any TLVs.
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = self._decode_extack(self.raw[extack_off:])
            if attr_space:
                self.annotate_extack(attr_space)

    def _error_extack_offset(self):
        """Return where extack TLVs start within an NLMSG_ERROR payload.

        The payload begins with the 4-byte error code and the echoed request
        nlmsghdr (20 bytes).  Unless the reply is NLM_F_CAPPED, the request
        body is echoed as well, padded to 4 bytes, and TLVs follow it.
        """
        if (self.nl_flags & Netlink.NLM_F_CAPPED) or \
                not (self.nl_flags & Netlink.NLM_F_ACK_TLVS):
            return 20
        echoed_len = struct.unpack_from("I", self.raw, 4)[0]
        return 4 + ((echoed_len + 3) & ~3)

    def _decode_extack(self, raw):
        """Decode extended-ack TLVs from ``raw`` into a dict."""
        extack = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NLMSGERR_ATTR_MSG:
                extack['msg'] = attr.as_strz()
            elif attr.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                extack['miss-type'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                extack['miss-nest'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NLMSGERR_ATTR_OFFS:
                extack['bad-attr-offs'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NLMSGERR_ATTR_POLICY:
                extack['policy'] = self._decode_policy(attr.raw)
            else:
                extack.setdefault('unknown', []).append(attr)
        return extack

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY nest into a dict."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_ = attr.as_scalar('u32')
                try:
                    policy['type'] = Netlink.AttrType(type_).name
                except ValueError:
                    # A type missing from AttrType; keep the number instead of
                    # failing to decode the whole error reply.
                    policy['type'] = type_
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def annotate_extack(self, attr_space):
        """ Make extack more human friendly with attribute information """

        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
            miss_type = self.extack['miss-type']
            if miss_type in attr_space.attrs_by_val:
                spec = attr_space.attrs_by_val[miss_type]
                self.extack['miss-type'] = spec['name']
                if 'doc' in spec:
                    self.extack['miss-type-doc'] = spec['doc']

    def cmd(self):
        return self.nl_type

    def __repr__(self):
        msg = (f"nl_len = {self.nl_len} ({len(self.raw)}) "
               f"nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}")
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg
306
307
# pylint: disable=too-few-public-methods
class NlMsgs:
    """All netlink messages contained in one receive buffer, in order."""

    def __init__(self, data):
        self.msgs = []
        pos = 0
        total = len(data)
        while pos < total:
            parsed = NlMsg(data, pos)
            self.msgs.append(parsed)
            pos += parsed.nl_len

    def __iter__(self):
        return iter(self.msgs)
321
322
323def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
324    # we prepend length in _genl_msg_finalize()
325    if seq is None:
326        seq = random.randint(1, 1024)
327    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
328    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
329    return nlmsg + genlmsg
330
331
332def _genl_msg_finalize(msg):
333    return struct.pack("I", len(msg) + 4) + msg
334
335
def _genl_parse_mcast_groups(groups_attr):
    """Return {group name: group id} from a CTRL_ATTR_MCAST_GROUPS nest."""
    groups = {}
    for entry in NlAttrs(groups_attr.raw):
        grp_name = None
        grp_id = None
        for field in NlAttrs(entry.raw):
            if field.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                grp_name = field.as_strz()
            elif field.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                grp_id = field.as_scalar('u32')
        if grp_name and grp_id is not None:
            groups[grp_name] = grp_id
    return groups


def _genl_parse_family(genl_msg):
    """Collect id, name, maxattr and mcast groups from one family reply."""
    fam = {}
    for attr in NlAttrs(genl_msg.raw):
        if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
            fam['id'] = attr.as_scalar('u16')
        elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
            fam['name'] = attr.as_strz()
        elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
            fam['maxattr'] = attr.as_scalar('u32')
        elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
            fam['mcast'] = _genl_parse_mcast_groups(attr)
    return fam


def _genl_load_families():
    """Dump every generic netlink family known to the kernel via nlctrl.

    Returns:
        dict mapping family name to a dict holding 'id', 'name' and, when
        reported, 'maxattr' and 'mcast' (group name -> group id).

    Raises:
        YnlException: if a reply carries a netlink error.
    """
    families = {}

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL,
                      Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                      Netlink.CTRL_CMD_GETFAMILY, 1))
        sock.send(request, 0)

        # Keep reading until the dump's terminating (done) message arrives.
        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                if nl_msg.error:
                    raise YnlException(f"Netlink error: {nl_msg.error}")
                if nl_msg.done:
                    return families
                fam = _genl_parse_family(GenlMsg(nl_msg))
                if 'name' in fam and 'id' in fam:
                    families[fam['name']] = fam
382
383
class GenlMsg:
    """A generic netlink message: an NlMsg plus its 4-byte genlmsghdr.

    ``raw`` is the payload after the genl header; ``raw_attrs`` starts empty
    and is filled in by the protocol's decode step.
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        cmd, version, _reserved = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.genl_cmd = cmd
        self.genl_version = version
        self.raw = nl_msg.raw[4:]
        self.raw_attrs = []

    def cmd(self):
        return self.genl_cmd

    def __repr__(self):
        parts = [repr(self.nl),
                 f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"]
        parts.extend(f"\t\t{attr!r}\n" for attr in self.raw_attrs)
        return ''.join(parts)
400
401
class NetlinkProtocol:
    """Message framing for classic (non-generic) netlink families."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack an nlmsghdr without nlmsg_len: type, flags, seq, portid 0."""
        sequence = random.randint(1, 1024) if seq is None else seq
        return struct.pack("HHII", nl_type, nl_flags, sequence, 0)

    def message(self, flags, command, _version, seq=None):
        # No version field here; the command itself is the message type.
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap ``nl_msg`` and parse its attributes past any fixed header."""
        decoded = self._decode(nl_msg)
        op = op if op is not None else ynl.rsp_by_value[decoded.cmd()]
        attrs_offset = ynl.struct_size(op.fixed_header)
        decoded.raw_attrs = NlAttrs(decoded.raw, attrs_offset)
        return decoded

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the spec-defined id of multicast group ``mcast_name``."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise YnlException(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        # Size of the nlmsghdr alone.
        return 16
434
435
class GenlProtocol(NetlinkProtocol):
    """Message framing for generic netlink families.

    Family ids and multicast group ids come from a name-keyed table that is
    dumped from nlctrl on first use and cached on the class.
    """
    genl_family_name_to_id = {}

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        if not GenlProtocol.genl_family_name_to_id:
            GenlProtocol.genl_family_name_to_id = _genl_load_families()

        # KeyError here means the family was not in the kernel's dump.
        self.genl_family = GenlProtocol.genl_family_name_to_id[family_name]
        self.family_id = self.genl_family['id']

    def message(self, flags, command, version, seq=None):
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the kernel-assigned id of multicast group ``mcast_name``.

        ``mcast_groups`` is unused: the ids come from the family dump.  A
        family that reported no multicast groups has no 'mcast' entry, so it
        is treated as having an empty group table rather than raising
        KeyError.

        Raises:
            YnlException: the family has no group called ``mcast_name``.
        """
        groups = self.genl_family.get('mcast', {})
        if mcast_name not in groups:
            raise YnlException(f'Multicast group "{mcast_name}" not present in the family')
        return groups[mcast_name]

    def msghdr_size(self):
        # nlmsghdr plus the 4-byte genlmsghdr.
        return super().msghdr_size() + 4
463
464
# pylint: disable=too-few-public-methods
class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes for selector lookups.

    The scope created by the constructor is searched first, followed by the
    scopes of ``outer`` in order.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        self.scopes = [self.SpecValuesPair(attr_space, attrs)]
        if outer:
            self.scopes += outer.scopes

    def lookup(self, name):
        """Return the value of ``name`` from the first scope that defines it.

        Raises YnlException when that scope holds no value for ``name`` or
        when no scope defines the attribute at all.
        """
        scope = next((s for s in self.scopes if name in s.spec), None)
        if scope is None:
            raise YnlException(f"Attribute '{name}' not defined in any attribute-set")
        if name not in scope.values:
            spec_name = scope.spec.yaml['name']
            raise YnlException(
                f"No value for '{name}' in attribute space '{spec_name}'")
        return scope.values[name]
483
484
485#
486# YNL implementation details.
487#
488
489
490class YnlFamily(SpecFamily):
491    def __init__(self, def_path, schema=None, process_unknown=False,
492                 recv_size=0):
493        super().__init__(def_path, schema)
494
495        self.include_raw = False
496        self.process_unknown = process_unknown
497
498        try:
499            if self.proto == "netlink-raw":
500                self.nlproto = NetlinkProtocol(self.yaml['name'],
501                                               self.yaml['protonum'])
502            else:
503                self.nlproto = GenlProtocol(self.yaml['name'])
504        except KeyError as err:
505            raise YnlException(f"Family '{self.yaml['name']}' not supported by the kernel") from err
506
507        self._recv_dbg = False
508        # Note that netlink will use conservative (min) message size for
509        # the first dump recv() on the socket, our setting will only matter
510        # from the second recv() on.
511        self._recv_size = recv_size if recv_size else 131072
512        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
513        # for a message, so smaller receive sizes will lead to truncation.
514        # Note that the min size for other families may be larger than 4k!
515        if self._recv_size < 4000:
516            raise ConfigError()
517
518        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
519        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
520        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
521        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
522
523        self.async_msg_ids = set()
524        self.async_msg_queue = queue.Queue()
525
526        for msg in self.msgs.values():
527            if msg.is_async:
528                self.async_msg_ids.add(msg.rsp_value)
529
530        for op_name, op in self.ops.items():
531            bound_f = functools.partial(self._op, op_name)
532            setattr(self, op.ident_name, bound_f)
533
534
535    def ntf_subscribe(self, mcast_name):
536        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
537        self.sock.bind((0, 0))
538        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
539                             mcast_id)
540
541    def set_recv_dbg(self, enabled):
542        self._recv_dbg = enabled
543
544    def _recv_dbg_print(self, reply, nl_msgs):
545        if not self._recv_dbg:
546            return
547        print("Recv: read", len(reply), "bytes,",
548              len(nl_msgs.msgs), "messages", file=sys.stderr)
549        for nl_msg in nl_msgs:
550            print("  ", nl_msg, file=sys.stderr)
551
552    def _encode_enum(self, attr_spec, value):
553        enum = self.consts[attr_spec['enum']]
554        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
555            scalar = 0
556            if isinstance(value, str):
557                value = [value]
558            for single_value in value:
559                scalar += enum.entries[single_value].user_value(as_flags = True)
560            return scalar
561        return enum.entries[value].user_value()
562
563    def _get_scalar(self, attr_spec, value):
564        try:
565            return int(value)
566        except (ValueError, TypeError) as e:
567            if 'enum' in attr_spec:
568                return self._encode_enum(attr_spec, value)
569            if attr_spec.display_hint:
570                return self._from_string(value, attr_spec)
571            raise e
572
573    # pylint: disable=too-many-statements
574    def _add_attr(self, space, name, value, search_attrs):
575        try:
576            attr = self.attr_sets[space][name]
577        except KeyError as err:
578            raise YnlException(f"Space '{space}' has no attribute '{name}'") from err
579        nl_type = attr.value
580
581        if attr.is_multi and isinstance(value, list):
582            attr_payload = b''
583            for subvalue in value:
584                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
585            return attr_payload
586
587        if attr["type"] == 'nest':
588            nl_type |= Netlink.NLA_F_NESTED
589            sub_space = attr['nested-attributes']
590            attr_payload = self._add_nest_attrs(value, sub_space, search_attrs)
591        elif attr['type'] == 'indexed-array' and attr['sub-type'] == 'nest':
592            nl_type |= Netlink.NLA_F_NESTED
593            sub_space = attr['nested-attributes']
594            attr_payload = self._encode_indexed_array(value, sub_space,
595                                                      search_attrs)
596        elif attr["type"] == 'flag':
597            if not value:
598                # If value is absent or false then skip attribute creation.
599                return b''
600            attr_payload = b''
601        elif attr["type"] == 'string':
602            attr_payload = str(value).encode('ascii') + b'\x00'
603        elif attr["type"] == 'binary':
604            if value is None:
605                attr_payload = b''
606            elif isinstance(value, bytes):
607                attr_payload = value
608            elif isinstance(value, str):
609                if attr.display_hint:
610                    attr_payload = self._from_string(value, attr)
611                else:
612                    attr_payload = bytes.fromhex(value)
613            elif isinstance(value, dict) and attr.struct_name:
614                attr_payload = self._encode_struct(attr.struct_name, value)
615            elif isinstance(value, list) and attr.sub_type in NlAttr.type_formats:
616                format_ = NlAttr.get_format(attr.sub_type)
617                attr_payload = b''.join([format_.pack(x) for x in value])
618            else:
619                raise YnlException(f'Unknown type for binary attribute, value: {value}')
620        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
621            scalar = self._get_scalar(attr, value)
622            if attr.is_auto_scalar:
623                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
624            else:
625                attr_type = attr["type"]
626            format_ = NlAttr.get_format(attr_type, attr.byte_order)
627            attr_payload = format_.pack(scalar)
628        elif attr['type'] in "bitfield32":
629            scalar_value = self._get_scalar(attr, value["value"])
630            scalar_selector = self._get_scalar(attr, value["selector"])
631            attr_payload = struct.pack("II", scalar_value, scalar_selector)
632        elif attr['type'] == 'sub-message':
633            msg_format, _ = self._resolve_selector(attr, search_attrs)
634            attr_payload = b''
635            if msg_format.fixed_header:
636                attr_payload += self._encode_struct(msg_format.fixed_header, value)
637            if msg_format.attr_set:
638                if msg_format.attr_set in self.attr_sets:
639                    nl_type |= Netlink.NLA_F_NESTED
640                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
641                    for subname, subvalue in value.items():
642                        attr_payload += self._add_attr(msg_format.attr_set,
643                                                       subname, subvalue, sub_attrs)
644                else:
645                    raise YnlException(f"Unknown attribute-set '{msg_format.attr_set}'")
646        else:
647            raise YnlException(f'Unknown type at {space} {name} {value} {attr["type"]}')
648
649        return self._add_attr_raw(nl_type, attr_payload)
650
651    def _add_attr_raw(self, nl_type, attr_payload):
652        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
653        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
654
655    def _add_nest_attrs(self, value, sub_space, search_attrs):
656        sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
657        attr_payload = b''
658        for subname, subvalue in value.items():
659            attr_payload += self._add_attr(sub_space, subname, subvalue,
660                                           sub_attrs)
661        return attr_payload
662
663    def _encode_indexed_array(self, vals, sub_space, search_attrs):
664        attr_payload = b''
665        for i, val in enumerate(vals):
666            idx = i | Netlink.NLA_F_NESTED
667            val_payload = self._add_nest_attrs(val, sub_space, search_attrs)
668            attr_payload += self._add_attr_raw(idx, val_payload)
669        return attr_payload
670
671    def _get_enum_or_unknown(self, enum, raw):
672        try:
673            name = enum.entries_by_val[raw].name
674        except KeyError as error:
675            if self.process_unknown:
676                name = f"Unknown({raw})"
677            else:
678                raise error
679        return name
680
681    def _decode_enum(self, raw, attr_spec):
682        enum = self.consts[attr_spec['enum']]
683        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
684            i = 0
685            value = set()
686            while raw:
687                if raw & 1:
688                    value.add(self._get_enum_or_unknown(enum, i))
689                raw >>= 1
690                i += 1
691        else:
692            value = self._get_enum_or_unknown(enum, raw)
693        return value
694
695    def _decode_binary(self, attr, attr_spec):
696        if attr_spec.struct_name:
697            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
698        elif attr_spec.sub_type:
699            decoded = attr.as_c_array(attr_spec.sub_type)
700            if 'enum' in attr_spec:
701                decoded = [ self._decode_enum(x, attr_spec) for x in decoded ]
702            elif attr_spec.display_hint:
703                decoded = [ self._formatted_string(x, attr_spec.display_hint)
704                            for x in decoded ]
705        else:
706            decoded = attr.as_bin()
707            if attr_spec.display_hint:
708                decoded = self._formatted_string(decoded, attr_spec.display_hint)
709        return decoded
710
711    def _decode_array_attr(self, attr, attr_spec):
712        decoded = []
713        offset = 0
714        while offset < len(attr.raw):
715            item = NlAttr(attr.raw, offset)
716            offset += item.full_len
717
718            if attr_spec["sub-type"] == 'nest':
719                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
720                decoded.append({ item.type: subattrs })
721            elif attr_spec["sub-type"] == 'binary':
722                subattr = item.as_bin()
723                if attr_spec.display_hint:
724                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
725                decoded.append(subattr)
726            elif attr_spec["sub-type"] in NlAttr.type_formats:
727                subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
728                if 'enum' in attr_spec:
729                    subattr = self._decode_enum(subattr, attr_spec)
730                elif attr_spec.display_hint:
731                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
732                decoded.append(subattr)
733            else:
734                raise YnlException(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
735        return decoded
736
737    def _decode_nest_type_value(self, attr, attr_spec):
738        decoded = {}
739        value = attr
740        for name in attr_spec['type-value']:
741            value = NlAttr(value.raw, 0)
742            decoded[name] = value.type
743        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
744        decoded.update(subattrs)
745        return decoded
746
747    def _decode_unknown(self, attr):
748        if attr.is_nest:
749            return self._decode(NlAttrs(attr.raw), None)
750        return attr.as_bin()
751
752    def _rsp_add(self, rsp, name, is_multi, decoded):
753        if is_multi is None:
754            if name in rsp and not isinstance(rsp[name], list):
755                rsp[name] = [rsp[name]]
756                is_multi = True
757            else:
758                is_multi = False
759
760        if not is_multi:
761            rsp[name] = decoded
762        elif name in rsp:
763            rsp[name].append(decoded)
764        else:
765            rsp[name] = [decoded]
766
767    def _resolve_selector(self, attr_spec, search_attrs):
768        sub_msg = attr_spec.sub_message
769        if sub_msg not in self.sub_msgs:
770            raise YnlException(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
771        sub_msg_spec = self.sub_msgs[sub_msg]
772
773        selector = attr_spec.selector
774        value = search_attrs.lookup(selector)
775        if value not in sub_msg_spec.formats:
776            raise YnlException(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
777
778        spec = sub_msg_spec.formats[value]
779        return spec, value
780
781    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
782        msg_format, _ = self._resolve_selector(attr_spec, search_attrs)
783        decoded = {}
784        offset = 0
785        if msg_format.fixed_header:
786            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
787            offset = self.struct_size(msg_format.fixed_header)
788        if msg_format.attr_set:
789            if msg_format.attr_set in self.attr_sets:
790                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
791                decoded.update(subdict)
792            else:
793                raise YnlException(f"Unknown attribute-set '{msg_format.attr_set}' "
794                                   f"when decoding '{attr_spec.name}'")
795        return decoded
796
797    # pylint: disable=too-many-statements
798    def _decode(self, attrs, space, outer_attrs = None):
799        rsp = {}
800        search_attrs = {}
801        if space:
802            attr_space = self.attr_sets[space]
803            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
804
805        for attr in attrs:
806            try:
807                attr_spec = attr_space.attrs_by_val[attr.type]
808            except (KeyError, UnboundLocalError) as err:
809                if not self.process_unknown:
810                    raise YnlException(f"Space '{space}' has no attribute "
811                                       f"with value '{attr.type}'") from err
812                attr_name = f"UnknownAttr({attr.type})"
813                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
814                continue
815
816            try:
817                if attr_spec["type"] == 'nest':
818                    subdict = self._decode(NlAttrs(attr.raw),
819                                           attr_spec['nested-attributes'],
820                                           search_attrs)
821                    decoded = subdict
822                elif attr_spec["type"] == 'string':
823                    decoded = attr.as_strz()
824                elif attr_spec["type"] == 'binary':
825                    decoded = self._decode_binary(attr, attr_spec)
826                elif attr_spec["type"] == 'flag':
827                    decoded = True
828                elif attr_spec.is_auto_scalar:
829                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
830                    if 'enum' in attr_spec:
831                        decoded = self._decode_enum(decoded, attr_spec)
832                elif attr_spec["type"] in NlAttr.type_formats:
833                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
834                    if 'enum' in attr_spec:
835                        decoded = self._decode_enum(decoded, attr_spec)
836                    elif attr_spec.display_hint:
837                        decoded = self._formatted_string(decoded, attr_spec.display_hint)
838                elif attr_spec["type"] == 'indexed-array':
839                    decoded = self._decode_array_attr(attr, attr_spec)
840                elif attr_spec["type"] == 'bitfield32':
841                    value, selector = struct.unpack("II", attr.raw)
842                    if 'enum' in attr_spec:
843                        value = self._decode_enum(value, attr_spec)
844                        selector = self._decode_enum(selector, attr_spec)
845                    decoded = {"value": value, "selector": selector}
846                elif attr_spec["type"] == 'sub-message':
847                    decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
848                elif attr_spec["type"] == 'nest-type-value':
849                    decoded = self._decode_nest_type_value(attr, attr_spec)
850                else:
851                    if not self.process_unknown:
852                        raise YnlException(f'Unknown {attr_spec["type"]} '
853                                           f'with name {attr_spec["name"]}')
854                    decoded = self._decode_unknown(attr)
855
856                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
857            except:
858                print(f"Error decoding '{attr_spec.name}' from '{space}'")
859                raise
860
861        return rsp
862
863    # pylint: disable=too-many-arguments, too-many-positional-arguments
864    def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs):
865        for attr in attrs:
866            try:
867                attr_spec = attr_set.attrs_by_val[attr.type]
868            except KeyError as err:
869                raise YnlException(
870                    f"Space '{attr_set.name}' has no attribute with value '{attr.type}'") from err
871            if offset > target:
872                break
873            if offset == target:
874                return '.' + attr_spec.name
875
876            if offset + attr.full_len <= target:
877                offset += attr.full_len
878                continue
879
880            pathname = attr_spec.name
881            if attr_spec['type'] == 'nest':
882                sub_attrs = self.attr_sets[attr_spec['nested-attributes']]
883                search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name']))
884            elif attr_spec['type'] == 'sub-message':
885                msg_format, value = self._resolve_selector(attr_spec, search_attrs)
886                if msg_format is None:
887                    raise YnlException(f"Can't resolve sub-message of "
888                                       f"{attr_spec['name']} for extack")
889                sub_attrs = self.attr_sets[msg_format.attr_set]
890                pathname += f"({value})"
891            else:
892                raise YnlException(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
893            offset += 4
894            subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs,
895                                               offset, target, search_attrs)
896            if subpath is None:
897                return None
898            return '.' + pathname + subpath
899
900        return None
901
902    def _decode_extack(self, request, op, extack, vals):
903        if 'bad-attr-offs' not in extack:
904            return
905
906        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
907        offset = self.nlproto.msghdr_size() + self.struct_size(op.fixed_header)
908        search_attrs = SpaceAttrs(op.attr_set, vals)
909        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
910                                        extack['bad-attr-offs'], search_attrs)
911        if path:
912            del extack['bad-attr-offs']
913            extack['bad-attr'] = path
914
915    def struct_size(self, name):
916        if name:
917            members = self.consts[name].members
918            size = 0
919            for m in members:
920                if m.type in ['pad', 'binary']:
921                    if m.struct:
922                        size += self.struct_size(m.struct)
923                    else:
924                        size += m.len
925                else:
926                    format_ = NlAttr.get_format(m.type, m.byte_order)
927                    size += format_.size
928            return size
929        return 0
930
931    def _decode_struct(self, data, name):
932        members = self.consts[name].members
933        attrs = {}
934        offset = 0
935        for m in members:
936            value = None
937            if m.type == 'pad':
938                offset += m.len
939            elif m.type == 'binary':
940                if m.struct:
941                    len_ = self.struct_size(m.struct)
942                    value = self._decode_struct(data[offset : offset + len_],
943                                                m.struct)
944                    offset += len_
945                else:
946                    value = data[offset : offset + m.len]
947                    offset += m.len
948            else:
949                format_ = NlAttr.get_format(m.type, m.byte_order)
950                [ value ] = format_.unpack_from(data, offset)
951                offset += format_.size
952            if value is not None:
953                if m.enum:
954                    value = self._decode_enum(value, m)
955                elif m.display_hint:
956                    value = self._formatted_string(value, m.display_hint)
957                attrs[m.name] = value
958        return attrs
959
960    def _encode_struct(self, name, vals):
961        members = self.consts[name].members
962        attr_payload = b''
963        for m in members:
964            value = vals.pop(m.name) if m.name in vals else None
965            if m.type == 'pad':
966                attr_payload += bytearray(m.len)
967            elif m.type == 'binary':
968                if m.struct:
969                    if value is None:
970                        value = {}
971                    attr_payload += self._encode_struct(m.struct, value)
972                else:
973                    if value is None:
974                        attr_payload += bytearray(m.len)
975                    else:
976                        attr_payload += bytes.fromhex(value)
977            else:
978                if value is None:
979                    value = 0
980                format_ = NlAttr.get_format(m.type, m.byte_order)
981                attr_payload += format_.pack(value)
982        return attr_payload
983
984    def _formatted_string(self, raw, display_hint):
985        if display_hint == 'mac':
986            formatted = ':'.join(f'{b:02x}' for b in raw)
987        elif display_hint == 'hex':
988            if isinstance(raw, int):
989                formatted = hex(raw)
990            else:
991                formatted = bytes.hex(raw, ' ')
992        elif display_hint in [ 'ipv4', 'ipv6', 'ipv4-or-v6' ]:
993            formatted = format(ipaddress.ip_address(raw))
994        elif display_hint == 'uuid':
995            formatted = str(uuid.UUID(bytes=raw))
996        else:
997            formatted = raw
998        return formatted
999
1000    def _from_string(self, string, attr_spec):
1001        if attr_spec.display_hint in ['ipv4', 'ipv6', 'ipv4-or-v6']:
1002            ip = ipaddress.ip_address(string)
1003            if attr_spec['type'] == 'binary':
1004                raw = ip.packed
1005            else:
1006                raw = int(ip)
1007        elif attr_spec.display_hint == 'hex':
1008            if attr_spec['type'] == 'binary':
1009                raw = bytes.fromhex(string)
1010            else:
1011                raw = int(string, 16)
1012        elif attr_spec.display_hint == 'mac':
1013            # Parse MAC address in format "00:11:22:33:44:55" or "001122334455"
1014            if ':' in string:
1015                mac_bytes = [int(x, 16) for x in string.split(':')]
1016            else:
1017                if len(string) % 2 != 0:
1018                    raise YnlException(f"Invalid MAC address format: {string}")
1019                mac_bytes = [int(string[i:i+2], 16) for i in range(0, len(string), 2)]
1020            raw = bytes(mac_bytes)
1021        else:
1022            raise YnlException(f"Display hint '{attr_spec.display_hint}' not implemented"
1023                            f" when parsing '{attr_spec['name']}'")
1024        return raw
1025
1026    def handle_ntf(self, decoded):
1027        msg = {}
1028        if self.include_raw:
1029            msg['raw'] = decoded
1030        op = self.rsp_by_value[decoded.cmd()]
1031        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
1032        if op.fixed_header:
1033            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
1034
1035        msg['name'] = op['name']
1036        msg['msg'] = attrs
1037        self.async_msg_queue.put(msg)
1038
1039    def check_ntf(self):
1040        while True:
1041            try:
1042                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
1043            except BlockingIOError:
1044                return
1045
1046            nms = NlMsgs(reply)
1047            self._recv_dbg_print(reply, nms)
1048            for nl_msg in nms:
1049                if nl_msg.error:
1050                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
1051                    print(nl_msg)
1052                    continue
1053                if nl_msg.done:
1054                    print("Netlink done while checking for ntf!?")
1055                    continue
1056
1057                decoded = self.nlproto.decode(self, nl_msg, None)
1058                if decoded.cmd() not in self.async_msg_ids:
1059                    print("Unexpected msg id while checking for ntf", decoded)
1060                    continue
1061
1062                self.handle_ntf(decoded)
1063
1064    def poll_ntf(self, duration=None):
1065        start_time = time.time()
1066        selector = selectors.DefaultSelector()
1067        selector.register(self.sock, selectors.EVENT_READ)
1068
1069        while True:
1070            try:
1071                yield self.async_msg_queue.get_nowait()
1072            except queue.Empty:
1073                if duration is not None:
1074                    timeout = start_time + duration - time.time()
1075                    if timeout <= 0:
1076                        return
1077                else:
1078                    timeout = None
1079                events = selector.select(timeout)
1080                if events:
1081                    self.check_ntf()
1082
1083    def operation_do_attributes(self, name):
1084        """
1085        For a given operation name, find and return a supported
1086        set of attributes (as a dict).
1087        """
1088        op = self.find_operation(name)
1089        if not op:
1090            return None
1091
1092        return op['do']['request']['attributes'].copy()
1093
1094    def _encode_message(self, op, vals, flags, req_seq):
1095        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
1096        for flag in flags or []:
1097            nl_flags |= flag
1098
1099        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
1100        if op.fixed_header:
1101            msg += self._encode_struct(op.fixed_header, vals)
1102        search_attrs = SpaceAttrs(op.attr_set, vals)
1103        for name, value in vals.items():
1104            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
1105        msg = _genl_msg_finalize(msg)
1106        return msg
1107
1108    # pylint: disable=too-many-statements
1109    def _ops(self, ops):
1110        reqs_by_seq = {}
1111        req_seq = random.randint(1024, 65535)
1112        payload = b''
1113        for (method, vals, flags) in ops:
1114            op = self.ops[method]
1115            msg = self._encode_message(op, vals, flags, req_seq)
1116            reqs_by_seq[req_seq] = (op, vals, msg, flags)
1117            payload += msg
1118            req_seq += 1
1119
1120        self.sock.send(payload, 0)
1121
1122        done = False
1123        rsp = []
1124        op_rsp = []
1125        while not done:
1126            reply = self.sock.recv(self._recv_size)
1127            nms = NlMsgs(reply)
1128            self._recv_dbg_print(reply, nms)
1129            for nl_msg in nms:
1130                if nl_msg.nl_seq in reqs_by_seq:
1131                    (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
1132                    if nl_msg.extack:
1133                        nl_msg.annotate_extack(op.attr_set)
1134                        self._decode_extack(req_msg, op, nl_msg.extack, vals)
1135                else:
1136                    op = None
1137                    req_flags = []
1138
1139                if nl_msg.error:
1140                    raise NlError(nl_msg)
1141                if nl_msg.done:
1142                    if nl_msg.extack:
1143                        print("Netlink warning:")
1144                        print(nl_msg)
1145
1146                    if Netlink.NLM_F_DUMP in req_flags:
1147                        rsp.append(op_rsp)
1148                    elif not op_rsp:
1149                        rsp.append(None)
1150                    elif len(op_rsp) == 1:
1151                        rsp.append(op_rsp[0])
1152                    else:
1153                        rsp.append(op_rsp)
1154                    op_rsp = []
1155
1156                    del reqs_by_seq[nl_msg.nl_seq]
1157                    done = len(reqs_by_seq) == 0
1158                    break
1159
1160                decoded = self.nlproto.decode(self, nl_msg, op)
1161
1162                # Check if this is a reply to our request
1163                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
1164                    if decoded.cmd() in self.async_msg_ids:
1165                        self.handle_ntf(decoded)
1166                        continue
1167                    print('Unexpected message: ' + repr(decoded))
1168                    continue
1169
1170                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
1171                if op.fixed_header:
1172                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
1173                op_rsp.append(rsp_msg)
1174
1175        return rsp
1176
1177    def _op(self, method, vals, flags=None, dump=False):
1178        req_flags = flags or []
1179        if dump:
1180            req_flags.append(Netlink.NLM_F_DUMP)
1181
1182        ops = [(method, vals, req_flags)]
1183        return self._ops(ops)[0]
1184
1185    def do(self, method, vals, flags=None):
1186        return self._op(method, vals, flags)
1187
1188    def dump(self, method, vals):
1189        return self._op(method, vals, dump=True)
1190
1191    def do_multi(self, ops):
1192        return self._ops(ops)
1193