xref: /linux/tools/net/ynl/pyynl/lib/ynl.py (revision 77de28cd7cf172e782319a144bf64e693794d78b)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2#
3# pylint: disable=missing-class-docstring, missing-function-docstring
4# pylint: disable=too-many-branches, too-many-locals, too-many-instance-attributes
5# pylint: disable=too-many-lines
6
7"""
8YAML Netlink Library
9
10An implementation of the genetlink and raw netlink protocols.
11"""
12
13from collections import namedtuple
14from enum import Enum
15import functools
16import os
17import random
18import socket
19import struct
20from struct import Struct
21import sys
22import ipaddress
23import uuid
24import queue
25import selectors
26import time
27
28from .nlspec import SpecFamily
29
30#
31# Generic Netlink code which should really be in some library, but I can't quickly find one.
32#
33
34
class YnlException(Exception):
    """Generic error raised by the YNL library itself.

    Used throughout this module for problems such as unknown attributes,
    malformed payloads and families the kernel does not provide.
    """
37
38
39# pylint: disable=too-few-public-methods
class Netlink:
    """Namespace for the Netlink / Generic Netlink constants used below.

    The class is never instantiated; it only groups numeric values so the
    rest of the module can refer to them as ``Netlink.<NAME>``.
    """

    # --- socket level -----------------------------------------------------
    SOL_NETLINK = 270

    # Socket options passed to setsockopt(SOL_NETLINK, ...)
    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # --- message header ---------------------------------------------------
    # Control message types
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # Flags common to all requests
    NLM_F_REQUEST = 1
    NLM_F_ACK = 4

    # The following three groups reuse the same bit positions; which
    # meaning applies depends on the kind of message carrying the flag.
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # --- attribute header -------------------------------------------------
    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    # Union of the flag bits stored in the attribute type field; users of
    # this constant clear these bits to obtain the bare attribute type.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # --- Generic Netlink --------------------------------------------------
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl commands
    CTRL_CMD_GETFAMILY = 3
    CTRL_CMD_GETPOLICY = 10

    # nlctrl top-level attributes
    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7
    CTRL_ATTR_POLICY = 8
    CTRL_ATTR_OP_POLICY = 9
    CTRL_ATTR_OP = 10

    # Attributes nested inside each multicast group entry
    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Attributes nested inside each per-op policy entry
    CTRL_ATTR_POLICY_DO = 1
    CTRL_ATTR_POLICY_DUMP = 2

    # --- extended ACK attributes -------------------------------------------
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # --- policy description attributes ---------------------------------------
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Names of the attribute types reported in a policy dump. Members are
    # numbered from 1 in list order, so the position in this list is what
    # ties a numeric NL_POLICY_TYPE_ATTR_TYPE value to its name.
    AttrType = Enum('AttrType', [
        'flag',
        'u8', 'u16', 'u32', 'u64',
        's8', 's16', 's32', 's64',
        'binary', 'string', 'nul-string',
        'nested', 'nested-array',
        'bitfield32', 'sint', 'uint',
    ])
123
class NlError(Exception):
    """Error returned by the kernel in reply to a Netlink request.

    Attributes:
        nl_msg: the message object the error was taken from; it must
            provide ``error`` and ``extack`` (a dict or a false value).
        error: ``nl_msg.error`` with the sign flipped, i.e. the value
            handed to os.strerror() when formatting the exception.
    """
    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        # Work on a copy so formatting never mutates the message's extack.
        extack = dict(self.nl_msg.extack) if self.nl_msg.extack else {}

        parts = ["Netlink error: "]
        if 'msg' in extack:
            # The kernel's text goes first; it is dropped from the dict so
            # it is not printed a second time below.
            parts.append(extack.pop('msg') + ': ')
        parts.append(os.strerror(self.error))
        if extack:
            parts.append(' ' + str(extack))
        return ''.join(parts)
140
141
class ConfigError(Exception):
    """Raised when the library is set up with unusable parameters."""
144
145
146# pylint: disable=too-few-public-methods
class NlPolicy:
    """Kernel validation policy for a single mode (do or dump) of one op.

    Instances are what YnlFamily.get_policy() hands back. The object is a
    thin wrapper around one dict describing the attributes the kernel
    accepts and the constraints it applies to them.

    Attributes:
        attrs: dict keyed by attribute name; each value is a policy dict.
        For example, the do policy of page-pool-stats-get::

            {
                'info': {'type': 'nested', 'policy': {
                    'id':      {'type': 'uint', 'min-value': 1,
                                'max-value': 4294967295},
                    'ifindex': {'type': 'u32', 'min-value': 1,
                                'max-value': 2147483647},
                }},
            }

        A policy dict always has a 'type' entry (e.g. u32, string,
        nested). The keys min-value, max-value, min-length, max-length,
        mask and policy are optional.
    """
    def __init__(self, attrs):
        # Stored by reference; the dict is not copied.
        self.attrs = attrs
172
173
class NlAttr:
    """A single Netlink attribute sliced out of a larger buffer.

    Attributes:
        type: attribute type with the NLA_F_* flag bits cleared.
        is_nest: non-zero when the NLA_F_NESTED bit was set.
        payload_len: length field from the header (header included).
        full_len: ``payload_len`` rounded up to a multiple of 4, i.e. how
            far to advance to reach the next attribute.
        raw: payload bytes following the 4-byte header.
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])

    # Pre-compiled struct formats per scalar type, one for each byte order.
    # Single-byte types need no byte order prefix.
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        header = raw[offset : offset + 4]
        self._len, self._type = struct.unpack("HH", header)
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        # Attributes are padded to a 4-byte boundary on the wire.
        self.full_len = (self.payload_len + 3) & ~3
        payload_end = offset + self.payload_len
        self.raw = raw[offset + 4 : payload_end]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for attr_type in the requested byte order.

        A false byte_order selects the native format, "big-endian" the
        big-endian one, and any other value the little-endian one.
        """
        formats = cls.type_formats[attr_type]
        if not byte_order:
            return formats.native
        if byte_order == "big-endian":
            return formats.big
        return formats.little

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as one scalar of the given type."""
        return self.get_format(attr_type, byte_order).unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable-width scalar whose width is the payload size.

        The first letter of attr_type selects the signedness prefix and
        the payload length (4 or 8 bytes) selects 32 or 64 bits.
        """
        size = len(self.raw)
        if size not in (4, 8):
            raise YnlException(f"Auto-scalar len payload be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(size * 8)
        return self.get_format(real_type, byte_order).unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII and drop its final character."""
        text = self.raw.decode('ascii')
        return text[:-1]

    def as_bin(self):
        """Return the payload bytes untouched."""
        return self.raw

    def as_c_array(self, c_type):
        """Decode the payload as a packed array of native c_type scalars."""
        unpacker = self.get_format(c_type)
        return [fields[0] for fields in unpacker.iter_unpack(self.raw)]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
226
227
class NlAttrs:
    """List of NlAttr objects parsed from back-to-back attributes.

    Args:
        msg: buffer holding the attributes.
        offset: position in msg of the first attribute header.

    Raises:
        YnlException: if an attribute header carries a length of zero.

    Iterating the object yields the parsed attributes in wire order.
    """
    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            if not attr.full_len:
                # A zero length would leave offset unchanged and spin here
                # forever, appending the same attribute over and over.
                raise YnlException(
                    f"Malformed attribute at offset {offset}: zero length")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        # One attribute per line, no trailing newline.
        return '\n'.join(repr(a) for a in self.attrs)
247
248
class NlMsg:
    """One Netlink message parsed out of a receive buffer.

    Args:
        msg: buffer containing the message.
        offset: position of the 16-byte message header within msg.
        attr_space: optional attribute set used to translate the
            numeric 'miss-type' extack entry into an attribute name.

    Attributes:
        hdr: the raw 16 header bytes.
        nl_len, nl_type, nl_flags, nl_seq, nl_portid: header fields.
        raw: message payload following the header.
        error: status code carried by ERROR / DONE messages, else 0.
        done: 1 for ERROR / DONE messages, else 0.
        extack: dict of decoded extended ACK attributes, or None.
    """
    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        (self.nl_len, self.nl_type, self.nl_flags,
         self.nl_seq, self.nl_portid) = struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        # Both control message types start with a 32-bit status code; they
        # differ in where the extack attributes begin within the payload.
        extack_off = {
            Netlink.NLMSG_ERROR: 20,
            Netlink.NLMSG_DONE: 4,
        }.get(self.nl_type)
        if extack_off is not None:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            # Extack attributes that decode as a plain u32.
            u32_keys = {
                Netlink.NLMSGERR_ATTR_MISS_TYPE: 'miss-type',
                Netlink.NLMSGERR_ATTR_MISS_NEST: 'miss-nest',
                Netlink.NLMSGERR_ATTR_OFFS: 'bad-attr-offs',
            }

            self.extack = {}
            for extack in NlAttrs(self.raw[extack_off:]):
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type in u32_keys:
                    self.extack[u32_keys[extack.type]] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = _genl_decode_policy(extack.raw)
                else:
                    # Keep anything unrecognised so it is not silently lost.
                    self.extack.setdefault('unknown', []).append(extack)

            if attr_space:
                self.annotate_extack(attr_space)

    def annotate_extack(self, attr_space):
        """ Make extack more human friendly with attribute information """

        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' not in self.extack or 'miss-nest' in self.extack:
            return

        miss_type = self.extack['miss-type']
        if miss_type not in attr_space.attrs_by_val:
            return

        spec = attr_space.attrs_by_val[miss_type]
        self.extack['miss-type'] = spec['name']
        if 'doc' in spec:
            self.extack['miss-type-doc'] = spec['doc']

    def cmd(self):
        """Return the message type, which doubles as the command id."""
        return self.nl_type

    def __repr__(self):
        pieces = [f"nl_len = {self.nl_len} ({len(self.raw)}) "
                  f"nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"]
        if self.error:
            pieces.append('\n\terror: ' + str(self.error))
        if self.extack:
            pieces.append('\n\textack: ' + repr(self.extack))
        return ''.join(pieces)
317
318
319# pylint: disable=too-few-public-methods
class NlMsgs:
    """All Netlink messages contained in one receive buffer.

    The buffer is walked front to back, using each message's length
    field to find the start of the next one. Iterating the object
    yields the NlMsg instances in buffer order.
    """
    def __init__(self, data):
        self.msgs = []

        pos = 0
        while pos < len(data):
            parsed = NlMsg(data, pos)
            pos += parsed.nl_len
            self.msgs.append(parsed)

    def __iter__(self):
        yield from self.msgs
332
333
334def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
335    # we prepend length in _genl_msg_finalize()
336    if seq is None:
337        seq = random.randint(1, 1024)
338    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
339    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
340    return nlmsg + genlmsg
341
342
343def _genl_msg_finalize(msg):
344    return struct.pack("I", len(msg) + 4) + msg
345
346
def _genl_decode_policy(raw):
    """Decode the policy description of one attribute into a dict.

    Args:
        raw: payload holding NL_POLICY_TYPE_ATTR_* attributes.

    Returns:
        dict with a subset of the keys 'type', 'min-value', 'max-value',
        'min-length', 'max-length', 'policy-idx', 'bitfield32-mask' and
        'mask'. Attributes not listed in the table below are skipped.
    """
    # attribute type -> (result key, scalar format used to decode it)
    scalar_fields = {
        Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S: ('min-value', 's64'),
        Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S: ('max-value', 's64'),
        Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U: ('min-value', 'u64'),
        Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U: ('max-value', 'u64'),
        Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH: ('min-length', 'u32'),
        Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH: ('max-length', 'u32'),
        Netlink.NL_POLICY_TYPE_ATTR_POLICY_IDX: ('policy-idx', 'u32'),
        Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK: ('bitfield32-mask', 'u32'),
        Netlink.NL_POLICY_TYPE_ATTR_MASK: ('mask', 'u64'),
    }

    policy = {}
    for attr in NlAttrs(raw):
        if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
            # The type is reported as a number; expose its symbolic name.
            policy['type'] = Netlink.AttrType(attr.as_scalar('u32')).name
            continue

        field = scalar_fields.get(attr.type)
        if field is not None:
            key, scalar_type = field
            policy[key] = attr.as_scalar(scalar_type)
    return policy
372
373
374# pylint: disable=too-many-nested-blocks
def _genl_load_families():
    """Ask the kernel for every registered Generic Netlink family.

    Opens a short-lived genetlink socket, dumps CTRL_CMD_GETFAMILY and
    parses the replies.

    Returns:
        dict mapping family name to a dict with the keys 'id' and 'name'
        and, when the kernel reported them, 'maxattr' and 'mcast' (a dict
        of multicast group name -> group id).

    Raises:
        YnlException: if the kernel answers with an error.
    """
    families = {}

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        dump_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP
        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL, dump_flags,
                      Netlink.CTRL_CMD_GETFAMILY, 1))
        sock.send(request, 0)

        # Keep reading until the kernel signals the end of the dump.
        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                if nl_msg.error:
                    raise YnlException(f"Netlink error: {nl_msg.error}")
                if nl_msg.done:
                    return families

                fam = {}
                for attr in NlAttrs(GenlMsg(nl_msg).raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        groups = {}
                        fam['mcast'] = groups
                        for entry in NlAttrs(attr.raw):
                            grp_name = None
                            grp_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    grp_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    grp_id = entry_attr.as_scalar('u32')
                            # Only record groups for which both parts arrived.
                            if grp_name and grp_id is not None:
                                groups[grp_name] = grp_id

                # Families missing a name or an id cannot be looked up later.
                if 'name' in fam and 'id' in fam:
                    families[fam['name']] = fam
420
421
422# pylint: disable=too-many-nested-blocks
def _genl_policy_dump(family_id, op):
    """Dump the kernel's validation policy for one operation of a family.

    Args:
        family_id: numeric Generic Netlink family id.
        op: numeric operation (command) id within that family.

    Returns:
        A tuple ``(op_policy, policy_table)``. ``op_policy`` maps 'do'
        and/or 'dump' to a policy index; ``policy_table`` maps a policy
        index to a dict of attribute id -> decoded policy dict.

    Raises:
        YnlException: if the kernel answers with an error.
    """
    op_policy = {}
    policy_table = {}

    # per-op policy attribute -> key used in op_policy
    mode_keys = {
        Netlink.CTRL_ATTR_POLICY_DO: 'do',
        Netlink.CTRL_ATTR_POLICY_DUMP: 'dump',
    }

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        dump_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP
        request = b''.join((
            _genl_msg(Netlink.GENL_ID_CTRL, dump_flags,
                      Netlink.CTRL_CMD_GETPOLICY, 1),
            # u16 family id, followed by two bytes of alignment padding
            struct.pack('HHHxx', 6, Netlink.CTRL_ATTR_FAMILY_ID, family_id),
            struct.pack('HHI', 8, Netlink.CTRL_ATTR_OP, op),
        ))
        sock.send(_genl_msg_finalize(request), 0)

        # Keep reading until the kernel signals the end of the dump.
        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                if nl_msg.error:
                    raise YnlException(f"Netlink error: {nl_msg.error}")
                if nl_msg.done:
                    return op_policy, policy_table

                for attr in NlAttrs(GenlMsg(nl_msg).raw):
                    if attr.type == Netlink.CTRL_ATTR_OP_POLICY:
                        for op_attr in NlAttrs(attr.raw):
                            for method_attr in NlAttrs(op_attr.raw):
                                if method_attr.type in mode_keys:
                                    key = mode_keys[method_attr.type]
                                    op_policy[key] = method_attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_POLICY:
                        # Two nesting levels: the outer attribute type is the
                        # policy index, the inner one the attribute id.
                        for pidx_attr in NlAttrs(attr.raw):
                            for aid_attr in NlAttrs(pidx_attr.raw):
                                decoded = _genl_decode_policy(aid_attr.raw)
                                per_policy = policy_table.setdefault(pidx_attr.type, {})
                                per_policy[aid_attr.type] = decoded
466
467
class GenlMsg:
    """Generic Netlink view of an already parsed Netlink message.

    Attributes:
        nl: the wrapped message object.
        genl_cmd, genl_version: fields of the 4-byte genetlink header.
        raw: payload following the genetlink header.
        raw_attrs: starts out as an empty list; filled in by the caller.
    """
    def __init__(self, nl_msg):
        self.nl = nl_msg
        # The third header field is reserved and ignored.
        header = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.genl_cmd, self.genl_version = header[:2]
        self.raw = nl_msg.raw[4:]
        self.raw_attrs = []

    def cmd(self):
        """Return the genetlink command id."""
        return self.genl_cmd

    def __repr__(self):
        pieces = [repr(self.nl),
                  f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"]
        pieces.extend('\t\t' + repr(a) + '\n' for a in self.raw_attrs)
        return ''.join(pieces)
484
485
class NetlinkProtocol:
    """Message framing for families spoken over raw Netlink.

    Attributes:
        family_name: name of the family this protocol object serves.
        proto_num: Netlink protocol number used when opening the socket.
    """
    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack the Netlink header minus its leading length field."""
        if seq is None:
            seq = random.randint(1, 1024)
        return struct.pack("HHII", nl_type, nl_flags, seq, 0)

    def message(self, flags, command, _version, seq=None):
        """Start a request; the command is used as the message type."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        """Return the object to attach attributes to (the message itself)."""
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Parse the attributes of nl_msg and store them in raw_attrs.

        When op is None it is looked up in ``ynl.rsp_by_value`` using the
        message's command. Attribute parsing starts after the op's fixed
        header, whose size is obtained from ``ynl.struct_size()``.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        attrs_start = ynl.struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, attrs_start)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the id of a multicast group as declared in the spec."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise YnlException(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        """Return the size in bytes of the Netlink message header."""
        return 16
518
519
class GenlProtocol(NetlinkProtocol):
    """Message framing for families spoken over Generic Netlink.

    Attributes:
        genl_family: the kernel's description of the family (a dict as
            produced by _genl_load_families()).
        family_id: numeric id the kernel assigned to the family.
    """
    # Shared across instances: filled from the kernel the first time a
    # GenlProtocol is created while the table is still empty.
    genl_family_name_to_id = {}

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        if not GenlProtocol.genl_family_name_to_id:
            GenlProtocol.genl_family_name_to_id = _genl_load_families()

        # A KeyError here means the kernel does not know the family.
        family = GenlProtocol.genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']

    def message(self, flags, command, version, seq=None):
        """Start a request: Netlink header fields plus genetlink header."""
        genl_hdr = struct.pack("BBH", command, version, 0)
        return self._message(self.family_id, flags, seq) + genl_hdr

    def _decode(self, nl_msg):
        """Wrap the message so the genetlink header is parsed."""
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the id the kernel reported for a multicast group."""
        known_groups = self.genl_family['mcast']
        if mcast_name in known_groups:
            return known_groups[mcast_name]
        raise YnlException(f'Multicast group "{mcast_name}" not present in the family')

    def msghdr_size(self):
        """Return the Netlink header size plus 4 for the genetlink header."""
        return super().msghdr_size() + 4
547
548
549# pylint: disable=too-few-public-methods
class SpaceAttrs:
    """Chain of (attribute set, values) scopes, innermost first.

    Args:
        attr_space: attribute set of the innermost scope.
        attrs: values supplied for that scope.
        outer: optional enclosing SpaceAttrs whose scopes are searched
            after the innermost one.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        self.scopes = [self.SpecValuesPair(attr_space, attrs)]
        if outer:
            # Extends our own fresh list; outer.scopes is left untouched.
            self.scopes += outer.scopes

    def lookup(self, name):
        """Return the value given for name in the nearest defining scope.

        Raises:
            YnlException: if the first scope that defines name has no
                value for it, or if no scope defines name at all.
        """
        for spec, values in self.scopes:
            if name not in spec:
                continue
            if name in values:
                return values[name]
            spec_name = spec.yaml['name']
            raise YnlException(
                f"No value for '{name}' in attribute space '{spec_name}'")
        raise YnlException(f"Attribute '{name}' not defined in any attribute-set")
567
568
569#
570# YNL implementation details.
571#
572
573
574class YnlFamily(SpecFamily):
575    """
576    YNL family -- a Netlink interface built from a YAML spec.
577
578    Primary use of the class is to execute Netlink commands:
579
580      ynl.<op_name>(attrs, ...)
581
582    By default this will execute the <op_name> as "do", pass dump=True
583    to perform a dump operation.
584
585    ynl.<op_name> is a shorthand / convenience wrapper for the following
586    methods which take the op_name as a string:
587
588      ynl.do(op_name, attrs, flags=None) -- execute a do operation
589      ynl.dump(op_name, attrs)           -- execute a dump operation
590      ynl.do_multi(ops)                  -- batch multiple do operations
591
592    The flags argument in ynl.do() allows passing in extra NLM_F_* flags
593    which may be necessary for old families.
594
595    Notification API:
596
597      ynl.ntf_subscribe(mcast_name)      -- join a multicast group
598      ynl.check_ntf()                    -- drain pending notifications
599      ynl.poll_ntf(duration=None)        -- yield notifications
600
601    Policy introspection allows querying validation criteria from the running
602    kernel. Allows checking whether kernel supports a given attribute or value.
603
604      ynl.get_policy(op_name, mode)      -- query kernel policy for an op
605    """
606    def __init__(self, def_path, schema=None, process_unknown=False,
607                 recv_size=0):
608        super().__init__(def_path, schema)
609
610        self.include_raw = False
611        self.process_unknown = process_unknown
612
613        try:
614            if self.proto == "netlink-raw":
615                self.nlproto = NetlinkProtocol(self.yaml['name'],
616                                               self.yaml['protonum'])
617            else:
618                self.nlproto = GenlProtocol(self.yaml['name'])
619        except KeyError as err:
620            raise YnlException(f"Family '{self.yaml['name']}' not supported by the kernel") from err
621
622        self._recv_dbg = False
623        # Note that netlink will use conservative (min) message size for
624        # the first dump recv() on the socket, our setting will only matter
625        # from the second recv() on.
626        self._recv_size = recv_size if recv_size else 131072
627        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
628        # for a message, so smaller receive sizes will lead to truncation.
629        # Note that the min size for other families may be larger than 4k!
630        if self._recv_size < 4000:
631            raise ConfigError()
632
633        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
634        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
635        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
636        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
637
638        self.async_msg_ids = set()
639        self.async_msg_queue = queue.Queue()
640
641        for msg in self.msgs.values():
642            if msg.is_async:
643                self.async_msg_ids.add(msg.rsp_value)
644
645        for op_name, op in self.ops.items():
646            bound_f = functools.partial(self._op, op_name)
647            setattr(self, op.ident_name, bound_f)
648
649
650    def ntf_subscribe(self, mcast_name):
651        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
652        self.sock.bind((0, 0))
653        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
654                             mcast_id)
655
656    def set_recv_dbg(self, enabled):
657        self._recv_dbg = enabled
658
659    def _recv_dbg_print(self, reply, nl_msgs):
660        if not self._recv_dbg:
661            return
662        print("Recv: read", len(reply), "bytes,",
663              len(nl_msgs.msgs), "messages", file=sys.stderr)
664        for nl_msg in nl_msgs:
665            print("  ", nl_msg, file=sys.stderr)
666
667    def _encode_enum(self, attr_spec, value):
668        enum = self.consts[attr_spec['enum']]
669        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
670            scalar = 0
671            if isinstance(value, str):
672                value = [value]
673            for single_value in value:
674                scalar += enum.entries[single_value].user_value(as_flags = True)
675            return scalar
676        return enum.entries[value].user_value()
677
678    def _get_scalar(self, attr_spec, value):
679        try:
680            return int(value)
681        except (ValueError, TypeError) as e:
682            if 'enum' in attr_spec:
683                return self._encode_enum(attr_spec, value)
684            if attr_spec.display_hint:
685                return self._from_string(value, attr_spec)
686            raise e
687
688    # pylint: disable=too-many-statements
689    def _add_attr(self, space, name, value, search_attrs):
690        try:
691            attr = self.attr_sets[space][name]
692        except KeyError as err:
693            raise YnlException(f"Space '{space}' has no attribute '{name}'") from err
694        nl_type = attr.value
695
696        if attr.is_multi and isinstance(value, list):
697            attr_payload = b''
698            for subvalue in value:
699                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
700            return attr_payload
701
702        if attr["type"] == 'nest':
703            nl_type |= Netlink.NLA_F_NESTED
704            sub_space = attr['nested-attributes']
705            attr_payload = self._add_nest_attrs(value, sub_space, search_attrs)
706        elif attr['type'] == 'indexed-array' and attr['sub-type'] == 'nest':
707            nl_type |= Netlink.NLA_F_NESTED
708            sub_space = attr['nested-attributes']
709            attr_payload = self._encode_indexed_array(value, sub_space,
710                                                      search_attrs)
711        elif attr["type"] == 'flag':
712            if not value:
713                # If value is absent or false then skip attribute creation.
714                return b''
715            attr_payload = b''
716        elif attr["type"] == 'string':
717            attr_payload = str(value).encode('ascii') + b'\x00'
718        elif attr["type"] == 'binary':
719            if value is None:
720                attr_payload = b''
721            elif isinstance(value, bytes):
722                attr_payload = value
723            elif isinstance(value, str):
724                if attr.display_hint:
725                    attr_payload = self._from_string(value, attr)
726                else:
727                    attr_payload = bytes.fromhex(value)
728            elif isinstance(value, dict) and attr.struct_name:
729                attr_payload = self._encode_struct(attr.struct_name, value)
730            elif isinstance(value, list) and attr.sub_type in NlAttr.type_formats:
731                format_ = NlAttr.get_format(attr.sub_type)
732                attr_payload = b''.join([format_.pack(x) for x in value])
733            else:
734                raise YnlException(f'Unknown type for binary attribute, value: {value}')
735        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
736            scalar = self._get_scalar(attr, value)
737            if attr.is_auto_scalar:
738                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
739            else:
740                attr_type = attr["type"]
741            format_ = NlAttr.get_format(attr_type, attr.byte_order)
742            attr_payload = format_.pack(scalar)
743        elif attr['type'] in "bitfield32":
744            scalar_value = self._get_scalar(attr, value["value"])
745            scalar_selector = self._get_scalar(attr, value["selector"])
746            attr_payload = struct.pack("II", scalar_value, scalar_selector)
747        elif attr['type'] == 'sub-message':
748            msg_format, _ = self._resolve_selector(attr, search_attrs)
749            attr_payload = b''
750            if msg_format.fixed_header:
751                attr_payload += self._encode_struct(msg_format.fixed_header, value)
752            if msg_format.attr_set:
753                if msg_format.attr_set in self.attr_sets:
754                    nl_type |= Netlink.NLA_F_NESTED
755                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
756                    for subname, subvalue in value.items():
757                        attr_payload += self._add_attr(msg_format.attr_set,
758                                                       subname, subvalue, sub_attrs)
759                else:
760                    raise YnlException(f"Unknown attribute-set '{msg_format.attr_set}'")
761        else:
762            raise YnlException(f'Unknown type at {space} {name} {value} {attr["type"]}')
763
764        return self._add_attr_raw(nl_type, attr_payload)
765
766    def _add_attr_raw(self, nl_type, attr_payload):
767        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
768        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
769
770    def _add_nest_attrs(self, value, sub_space, search_attrs):
771        sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
772        attr_payload = b''
773        for subname, subvalue in value.items():
774            attr_payload += self._add_attr(sub_space, subname, subvalue,
775                                           sub_attrs)
776        return attr_payload
777
778    def _encode_indexed_array(self, vals, sub_space, search_attrs):
779        attr_payload = b''
780        for i, val in enumerate(vals):
781            idx = i | Netlink.NLA_F_NESTED
782            val_payload = self._add_nest_attrs(val, sub_space, search_attrs)
783            attr_payload += self._add_attr_raw(idx, val_payload)
784        return attr_payload
785
786    def _get_enum_or_unknown(self, enum, raw):
787        try:
788            name = enum.entries_by_val[raw].name
789        except KeyError as error:
790            if self.process_unknown:
791                name = f"Unknown({raw})"
792            else:
793                raise error
794        return name
795
796    def _decode_enum(self, raw, attr_spec):
797        enum = self.consts[attr_spec['enum']]
798        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
799            i = 0
800            value = set()
801            while raw:
802                if raw & 1:
803                    value.add(self._get_enum_or_unknown(enum, i))
804                raw >>= 1
805                i += 1
806        else:
807            value = self._get_enum_or_unknown(enum, raw)
808        return value
809
810    def _decode_binary(self, attr, attr_spec):
811        if attr_spec.struct_name:
812            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
813        elif attr_spec.sub_type:
814            decoded = attr.as_c_array(attr_spec.sub_type)
815            if 'enum' in attr_spec:
816                decoded = [ self._decode_enum(x, attr_spec) for x in decoded ]
817            elif attr_spec.display_hint:
818                decoded = [ self._formatted_string(x, attr_spec.display_hint)
819                            for x in decoded ]
820        else:
821            decoded = attr.as_bin()
822            if attr_spec.display_hint:
823                decoded = self._formatted_string(decoded, attr_spec.display_hint)
824        return decoded
825
826    def _decode_array_attr(self, attr, attr_spec):
827        decoded = []
828        offset = 0
829        while offset < len(attr.raw):
830            item = NlAttr(attr.raw, offset)
831            offset += item.full_len
832
833            if attr_spec["sub-type"] == 'nest':
834                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
835                decoded.append({ item.type: subattrs })
836            elif attr_spec["sub-type"] == 'binary':
837                subattr = item.as_bin()
838                if attr_spec.display_hint:
839                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
840                decoded.append(subattr)
841            elif attr_spec["sub-type"] in NlAttr.type_formats:
842                subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
843                if 'enum' in attr_spec:
844                    subattr = self._decode_enum(subattr, attr_spec)
845                elif attr_spec.display_hint:
846                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
847                decoded.append(subattr)
848            else:
849                raise YnlException(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
850        return decoded
851
852    def _decode_nest_type_value(self, attr, attr_spec):
853        decoded = {}
854        value = attr
855        for name in attr_spec['type-value']:
856            value = NlAttr(value.raw, 0)
857            decoded[name] = value.type
858        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
859        decoded.update(subattrs)
860        return decoded
861
862    def _decode_unknown(self, attr):
863        if attr.is_nest:
864            return self._decode(NlAttrs(attr.raw), None)
865        return attr.as_bin()
866
867    def _rsp_add(self, rsp, name, is_multi, decoded):
868        if is_multi is None:
869            if name in rsp and not isinstance(rsp[name], list):
870                rsp[name] = [rsp[name]]
871                is_multi = True
872            else:
873                is_multi = False
874
875        if not is_multi:
876            rsp[name] = decoded
877        elif name in rsp:
878            rsp[name].append(decoded)
879        else:
880            rsp[name] = [decoded]
881
882    def _resolve_selector(self, attr_spec, search_attrs):
883        sub_msg = attr_spec.sub_message
884        if sub_msg not in self.sub_msgs:
885            raise YnlException(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
886        sub_msg_spec = self.sub_msgs[sub_msg]
887
888        selector = attr_spec.selector
889        value = search_attrs.lookup(selector)
890        if value not in sub_msg_spec.formats:
891            raise YnlException(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
892
893        spec = sub_msg_spec.formats[value]
894        return spec, value
895
896    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
897        msg_format, _ = self._resolve_selector(attr_spec, search_attrs)
898        decoded = {}
899        offset = 0
900        if msg_format.fixed_header:
901            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
902            offset = self.struct_size(msg_format.fixed_header)
903        if msg_format.attr_set:
904            if msg_format.attr_set in self.attr_sets:
905                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
906                decoded.update(subdict)
907            else:
908                raise YnlException(f"Unknown attribute-set '{msg_format.attr_set}' "
909                                   f"when decoding '{attr_spec.name}'")
910        return decoded
911
912    # pylint: disable=too-many-statements
    def _decode(self, attrs, space, outer_attrs = None):
        """
        Decode a sequence of attributes into a dict keyed by attribute name.

        Args:
            attrs: iterable of attributes to decode
            space: name of the attribute set to decode against; when falsy
                no spec is looked up and every attribute takes the
                unknown-attribute path
            outer_attrs: handed to SpaceAttrs as the enclosing scope

        Returns:
            dict of decoded values; 'pad' attributes are left out.

        Raises:
            YnlException: an attribute is not in the set, or its type is
                not supported, and process_unknown is not set.
        """
        rsp = {}
        search_attrs = {}
        if space:
            attr_space = self.attr_sets[space]
            # rsp is passed by reference and keeps being filled in by the
            # loop below.
            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)

        for attr in attrs:
            try:
                attr_spec = attr_space.attrs_by_val[attr.type]
            # UnboundLocalError: attr_space was never assigned because
            # space is falsy.
            except (KeyError, UnboundLocalError) as err:
                if not self.process_unknown:
                    raise YnlException(f"Space '{space}' has no attribute "
                                       f"with value '{attr.type}'") from err
                attr_name = f"UnknownAttr({attr.type})"
                # is_multi=None: multiplicity of an unknown attribute is
                # not known, let _rsp_add() work it out from repeats.
                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
                continue

            try:
                if attr_spec["type"] == 'pad':
                    continue
                elif attr_spec["type"] == 'nest':
                    subdict = self._decode(NlAttrs(attr.raw),
                                           attr_spec['nested-attributes'],
                                           search_attrs)
                    decoded = subdict
                elif attr_spec["type"] == 'string':
                    decoded = attr.as_strz()
                elif attr_spec["type"] == 'binary':
                    decoded = self._decode_binary(attr, attr_spec)
                elif attr_spec["type"] == 'flag':
                    # Presence of the attribute is the value.
                    decoded = True
                elif attr_spec.is_auto_scalar:
                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
                    if 'enum' in attr_spec:
                        decoded = self._decode_enum(decoded, attr_spec)
                elif attr_spec["type"] in NlAttr.type_formats:
                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
                    if 'enum' in attr_spec:
                        decoded = self._decode_enum(decoded, attr_spec)
                    elif attr_spec.display_hint:
                        decoded = self._formatted_string(decoded, attr_spec.display_hint)
                elif attr_spec["type"] == 'indexed-array':
                    decoded = self._decode_array_attr(attr, attr_spec)
                elif attr_spec["type"] == 'bitfield32':
                    # Payload is two native 32-bit unsigned integers.
                    value, selector = struct.unpack("II", attr.raw)
                    if 'enum' in attr_spec:
                        value = self._decode_enum(value, attr_spec)
                        selector = self._decode_enum(selector, attr_spec)
                    decoded = {"value": value, "selector": selector}
                elif attr_spec["type"] == 'sub-message':
                    decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
                elif attr_spec["type"] == 'nest-type-value':
                    decoded = self._decode_nest_type_value(attr, attr_spec)
                else:
                    if not self.process_unknown:
                        raise YnlException(f'Unknown {attr_spec["type"]} '
                                           f'with name {attr_spec["name"]}')
                    decoded = self._decode_unknown(attr)

                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
            except:
                # Say which attribute was being decoded, then let the
                # original exception propagate unchanged.
                print(f"Error decoding '{attr_spec.name}' from '{space}'")
                raise

        return rsp
979
980    # pylint: disable=too-many-arguments, too-many-positional-arguments
    def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs):
        """
        Find the dotted name path of the attribute at byte offset target.

        Args:
            attrs: attributes of the current nesting level
            attr_set: attribute set describing attrs
            offset: byte offset of the first attribute of attrs
            target: byte offset being searched for
            search_attrs: scope used for resolving sub-message selectors

        Returns:
            The path as a string with a leading '.', e.g. '.outer.inner',
            or None if no attribute starts at target.

        Raises:
            YnlException: an attribute is not in attr_set, or target lies
                inside an attribute that is neither a nest nor
                a sub-message.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError as err:
                raise YnlException(
                    f"Space '{attr_set.name}' has no attribute with value '{attr.type}'") from err
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            # Target is past this attribute, move on to the next one.
            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue

            # Target is inside this attribute, descend into it.
            pathname = attr_spec.name
            if attr_spec['type'] == 'nest':
                sub_attrs = self.attr_sets[attr_spec['nested-attributes']]
                search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name']))
            elif attr_spec['type'] == 'sub-message':
                msg_format, value = self._resolve_selector(attr_spec, search_attrs)
                # NOTE(review): _resolve_selector() as written in this file
                # raises instead of returning None - this check looks
                # purely defensive; confirm before removing.
                if msg_format is None:
                    raise YnlException(f"Can't resolve sub-message of "
                                       f"{attr_spec['name']} for extack")
                sub_attrs = self.attr_sets[msg_format.attr_set]
                pathname += f"({value})"
            else:
                raise YnlException(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Skip the 4 byte header of the attribute being entered.
            # NOTE(review): _decode_sub_msg() also skips the format's fixed
            # header before the attributes; that is not accounted for here
            # - confirm whether sub-messages with a fixed header resolve
            # to the right path.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs,
                                               offset, target, search_attrs)
            if subpath is None:
                return None
            return '.' + pathname + subpath

        return None
1018
1019    def _decode_extack(self, request, op, extack, vals):
1020        if 'bad-attr-offs' not in extack:
1021            return
1022
1023        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
1024        offset = self.nlproto.msghdr_size() + self.struct_size(op.fixed_header)
1025        search_attrs = SpaceAttrs(op.attr_set, vals)
1026        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
1027                                        extack['bad-attr-offs'], search_attrs)
1028        if path:
1029            del extack['bad-attr-offs']
1030            extack['bad-attr'] = path
1031
1032    def struct_size(self, name):
1033        if name:
1034            members = self.consts[name].members
1035            size = 0
1036            for m in members:
1037                if m.type in ['pad', 'binary']:
1038                    if m.struct:
1039                        size += self.struct_size(m.struct)
1040                    else:
1041                        size += m.len
1042                else:
1043                    format_ = NlAttr.get_format(m.type, m.byte_order)
1044                    size += format_.size
1045            return size
1046        return 0
1047
1048    def _decode_struct(self, data, name):
1049        members = self.consts[name].members
1050        attrs = {}
1051        offset = 0
1052        for m in members:
1053            value = None
1054            if m.type == 'pad':
1055                offset += m.len
1056            elif m.type == 'binary':
1057                if m.struct:
1058                    len_ = self.struct_size(m.struct)
1059                    value = self._decode_struct(data[offset : offset + len_],
1060                                                m.struct)
1061                    offset += len_
1062                else:
1063                    value = data[offset : offset + m.len]
1064                    offset += m.len
1065            else:
1066                format_ = NlAttr.get_format(m.type, m.byte_order)
1067                [ value ] = format_.unpack_from(data, offset)
1068                offset += format_.size
1069            if value is not None:
1070                if m.enum:
1071                    value = self._decode_enum(value, m)
1072                elif m.display_hint:
1073                    value = self._formatted_string(value, m.display_hint)
1074                attrs[m.name] = value
1075        return attrs
1076
1077    def _encode_struct(self, name, vals):
1078        members = self.consts[name].members
1079        attr_payload = b''
1080        for m in members:
1081            value = vals.pop(m.name) if m.name in vals else None
1082            if m.type == 'pad':
1083                attr_payload += bytearray(m.len)
1084            elif m.type == 'binary':
1085                if m.struct:
1086                    if value is None:
1087                        value = {}
1088                    attr_payload += self._encode_struct(m.struct, value)
1089                else:
1090                    if value is None:
1091                        attr_payload += bytearray(m.len)
1092                    else:
1093                        attr_payload += bytes.fromhex(value)
1094            else:
1095                if value is None:
1096                    value = 0
1097                format_ = NlAttr.get_format(m.type, m.byte_order)
1098                attr_payload += format_.pack(value)
1099        return attr_payload
1100
1101    def _formatted_string(self, raw, display_hint):
1102        if display_hint == 'mac':
1103            formatted = ':'.join(f'{b:02x}' for b in raw)
1104        elif display_hint == 'hex':
1105            if isinstance(raw, int):
1106                formatted = hex(raw)
1107            else:
1108                formatted = bytes.hex(raw, ' ')
1109        elif display_hint in [ 'ipv4', 'ipv6', 'ipv4-or-v6' ]:
1110            formatted = format(ipaddress.ip_address(raw))
1111        elif display_hint == 'uuid':
1112            formatted = str(uuid.UUID(bytes=raw))
1113        else:
1114            formatted = raw
1115        return formatted
1116
1117    def _from_string(self, string, attr_spec):
1118        if attr_spec.display_hint in ['ipv4', 'ipv6', 'ipv4-or-v6']:
1119            ip = ipaddress.ip_address(string)
1120            if attr_spec['type'] == 'binary':
1121                raw = ip.packed
1122            else:
1123                raw = int(ip)
1124        elif attr_spec.display_hint == 'hex':
1125            if attr_spec['type'] == 'binary':
1126                raw = bytes.fromhex(string)
1127            else:
1128                raw = int(string, 16)
1129        elif attr_spec.display_hint == 'mac':
1130            # Parse MAC address in format "00:11:22:33:44:55" or "001122334455"
1131            if ':' in string:
1132                mac_bytes = [int(x, 16) for x in string.split(':')]
1133            else:
1134                if len(string) % 2 != 0:
1135                    raise YnlException(f"Invalid MAC address format: {string}")
1136                mac_bytes = [int(string[i:i+2], 16) for i in range(0, len(string), 2)]
1137            raw = bytes(mac_bytes)
1138        else:
1139            raise YnlException(f"Display hint '{attr_spec.display_hint}' not implemented"
1140                            f" when parsing '{attr_spec['name']}'")
1141        return raw
1142
1143    def handle_ntf(self, decoded):
1144        msg = {}
1145        if self.include_raw:
1146            msg['raw'] = decoded
1147        op = self.rsp_by_value[decoded.cmd()]
1148        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
1149        if op.fixed_header:
1150            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
1151
1152        msg['name'] = op['name']
1153        msg['msg'] = attrs
1154        self.async_msg_queue.put(msg)
1155
1156    def check_ntf(self):
1157        while True:
1158            try:
1159                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
1160            except BlockingIOError:
1161                return
1162
1163            nms = NlMsgs(reply)
1164            self._recv_dbg_print(reply, nms)
1165            for nl_msg in nms:
1166                if nl_msg.error:
1167                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
1168                    print(nl_msg)
1169                    continue
1170                if nl_msg.done:
1171                    print("Netlink done while checking for ntf!?")
1172                    continue
1173
1174                decoded = self.nlproto.decode(self, nl_msg, None)
1175                if decoded.cmd() not in self.async_msg_ids:
1176                    print("Unexpected msg id while checking for ntf", decoded)
1177                    continue
1178
1179                self.handle_ntf(decoded)
1180
1181    def poll_ntf(self, duration=None):
1182        start_time = time.time()
1183        selector = selectors.DefaultSelector()
1184        selector.register(self.sock, selectors.EVENT_READ)
1185
1186        while True:
1187            try:
1188                yield self.async_msg_queue.get_nowait()
1189            except queue.Empty:
1190                if duration is not None:
1191                    timeout = start_time + duration - time.time()
1192                    if timeout <= 0:
1193                        return
1194                else:
1195                    timeout = None
1196                events = selector.select(timeout)
1197                if events:
1198                    self.check_ntf()
1199
1200    def operation_do_attributes(self, name):
1201        """
1202        For a given operation name, find and return a supported
1203        set of attributes (as a dict).
1204        """
1205        op = self.find_operation(name)
1206        if not op:
1207            return None
1208
1209        return op['do']['request']['attributes'].copy()
1210
1211    def _encode_message(self, op, vals, flags, req_seq):
1212        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
1213        for flag in flags or []:
1214            nl_flags |= flag
1215
1216        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
1217        if op.fixed_header:
1218            msg += self._encode_struct(op.fixed_header, vals)
1219        search_attrs = SpaceAttrs(op.attr_set, vals)
1220        for name, value in vals.items():
1221            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
1222        msg = _genl_msg_finalize(msg)
1223        return msg
1224
1225    # pylint: disable=too-many-statements
1226    def _ops(self, ops):
1227        reqs_by_seq = {}
1228        req_seq = random.randint(1024, 65535)
1229        payload = b''
1230        for (method, vals, flags) in ops:
1231            op = self.ops[method]
1232            msg = self._encode_message(op, vals, flags, req_seq)
1233            reqs_by_seq[req_seq] = (op, vals, msg, flags)
1234            payload += msg
1235            req_seq += 1
1236
1237        self.sock.send(payload, 0)
1238
1239        done = False
1240        rsp = []
1241        op_rsp = []
1242        while not done:
1243            reply = self.sock.recv(self._recv_size)
1244            nms = NlMsgs(reply)
1245            self._recv_dbg_print(reply, nms)
1246            for nl_msg in nms:
1247                if nl_msg.nl_seq in reqs_by_seq:
1248                    (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
1249                    if nl_msg.extack:
1250                        nl_msg.annotate_extack(op.attr_set)
1251                        self._decode_extack(req_msg, op, nl_msg.extack, vals)
1252                else:
1253                    op = None
1254                    req_flags = []
1255
1256                if nl_msg.error:
1257                    raise NlError(nl_msg)
1258                if nl_msg.done:
1259                    if nl_msg.extack:
1260                        print("Netlink warning:")
1261                        print(nl_msg)
1262
1263                    if Netlink.NLM_F_DUMP in req_flags:
1264                        rsp.append(op_rsp)
1265                    elif not op_rsp:
1266                        rsp.append(None)
1267                    elif len(op_rsp) == 1:
1268                        rsp.append(op_rsp[0])
1269                    else:
1270                        rsp.append(op_rsp)
1271                    op_rsp = []
1272
1273                    del reqs_by_seq[nl_msg.nl_seq]
1274                    done = len(reqs_by_seq) == 0
1275                    break
1276
1277                decoded = self.nlproto.decode(self, nl_msg, op)
1278
1279                # Check if this is a reply to our request
1280                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
1281                    if decoded.cmd() in self.async_msg_ids:
1282                        self.handle_ntf(decoded)
1283                        continue
1284                    print('Unexpected message: ' + repr(decoded))
1285                    continue
1286
1287                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
1288                if op.fixed_header:
1289                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
1290                op_rsp.append(rsp_msg)
1291
1292        return rsp
1293
1294    def _op(self, method, vals, flags=None, dump=False):
1295        req_flags = flags or []
1296        if dump:
1297            req_flags.append(Netlink.NLM_F_DUMP)
1298
1299        ops = [(method, vals, req_flags)]
1300        return self._ops(ops)[0]
1301
1302    def do(self, method, vals, flags=None):
1303        return self._op(method, vals, flags)
1304
1305    def dump(self, method, vals):
1306        return self._op(method, vals, dump=True)
1307
1308    def do_multi(self, ops):
1309        return self._ops(ops)
1310
1311    def _resolve_policy(self, policy_idx, policy_table, attr_set):
1312        attrs = {}
1313        if policy_idx not in policy_table:
1314            return attrs
1315        for attr_id, decoded in policy_table[policy_idx].items():
1316            if attr_set and attr_id in attr_set.attrs_by_val:
1317                spec = attr_set.attrs_by_val[attr_id]
1318                name = spec['name']
1319            else:
1320                spec = None
1321                name = f'attr-{attr_id}'
1322            if 'policy-idx' in decoded:
1323                sub_set = None
1324                if spec and 'nested-attributes' in spec.yaml:
1325                    sub_set = self.attr_sets[spec.yaml['nested-attributes']]
1326                nested = self._resolve_policy(decoded['policy-idx'],
1327                                              policy_table, sub_set)
1328                del decoded['policy-idx']
1329                decoded['policy'] = nested
1330            attrs[name] = decoded
1331        return attrs
1332
1333    def get_policy(self, op_name, mode):
1334        """Query running kernel for the Netlink policy of an operation.
1335
1336        Allows checking whether kernel supports a given attribute or value.
1337        This method consults the running kernel, not the YAML spec.
1338
1339        Args:
1340            op_name: operation name as it appears in the YAML spec
1341            mode: 'do' or 'dump'
1342
1343        Returns:
1344            NlPolicy with an attrs dict mapping attribute names to
1345            their policy properties (type, min/max, nested, etc.),
1346            or None if the operation has no policy for the given mode.
1347            Empty policy usually implies that the operation rejects
1348            all attributes.
1349        """
1350        op = self.ops[op_name]
1351        op_policy, policy_table = _genl_policy_dump(self.nlproto.family_id,
1352                                                    op.req_value)
1353        if mode not in op_policy:
1354            return None
1355        policy_idx = op_policy[mode]
1356        attrs = self._resolve_policy(policy_idx, policy_table, op.attr_set)
1357        return NlPolicy(attrs)
1358