xref: /linux/tools/net/ynl/pyynl/lib/ynl.py (revision 4ce06406958b67fdddcc2e6948237dd6ff6ba112)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2#
3# pylint: disable=missing-class-docstring, missing-function-docstring
4# pylint: disable=too-many-branches, too-many-locals, too-many-instance-attributes
5# pylint: disable=too-many-lines
6
7"""
8YAML Netlink Library
9
10An implementation of the genetlink and raw netlink protocols.
11"""
12
13from collections import namedtuple
14from enum import Enum
15import functools
16import os
17import random
18import socket
19import struct
20from struct import Struct
21import sys
22import ipaddress
23import uuid
24import queue
25import selectors
26import time
27
28from .nlspec import SpecFamily
29
30#
31# Generic Netlink code which should really be in some library, but I can't quickly find one.
32#
33
34
class YnlException(Exception):
    """General error raised by the YNL library itself.

    Used for problems such as unknown attributes or attribute sets,
    unencodable values, malformed payloads and netlink error replies seen
    while loading family information.
    """
37
38
39# pylint: disable=too-few-public-methods
class Netlink:
    """Numeric netlink and generic-netlink constants used by this module.

    The names follow the Linux UAPI headers (linux/netlink.h,
    linux/genetlink.h); the values are meant to match those headers.
    """
    # Netlink socket
    SOL_NETLINK = 270

    # Socket options set at the SOL_NETLINK level
    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    # Modifier flags; note the bit values overlap the group above
    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    # Reply flags (again overlapping bit values); NlMsg checks
    # NLM_F_ACK_TLVS to decide whether extended-ack TLVs are present.
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    # NOTE(review): despite its name this is the mask of the *flag* bits,
    # not the type bits; NlAttr clears it with `& ~NLA_TYPE_MASK` to get
    # the bare attribute type.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    # Fixed id of the generic netlink controller family ("nlctrl")
    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3
    CTRL_CMD_GETPOLICY = 10

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7
    CTRL_ATTR_POLICY = 8
    CTRL_ATTR_OP_POLICY = 9
    CTRL_ATTR_OP = 10

    # Attributes inside each CTRL_ATTR_MCAST_GROUPS entry
    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Attributes inside each CTRL_ATTR_OP_POLICY entry
    CTRL_ATTR_POLICY_DO = 1
    CTRL_ATTR_POLICY_DUMP = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Attribute types reported in kernel policies. The functional Enum API
    # numbers members from 1, so 'flag' == 1 ... 'uint' == 17;
    # _genl_decode_policy maps the kernel's numeric type through this.
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                  's8', 's16', 's32', 's64',
                                  'binary', 'string', 'nul-string',
                                  'nested', 'nested-array',
                                  'bitfield32', 'sint', 'uint'])
123
class NlError(Exception):
    """Exception wrapping a netlink message that carries a non-zero error.

    Attributes:
        nl_msg: the message (NlMsg) that carried the error.
        error: nl_msg.error negated, i.e. a positive errno value suitable
            for os.strerror().
    """
    def __init__(self, nl_msg):
        # Forward nl_msg to Exception so e.args is populated. Without it the
        # exception cannot round-trip through pickle/copy: reconstruction
        # calls NlError(*args), which would be NlError() and fail.
        super().__init__(nl_msg)
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        """Render as 'Netlink error: [<extack msg>: ]<strerror>[ <other extack>]'."""
        msg = "Netlink error: "

        # Work on a copy so that dropping 'msg' leaves nl_msg.extack intact.
        extack = self.nl_msg.extack.copy() if self.nl_msg.extack else {}
        if 'msg' in extack:
            msg += extack['msg'] + ': '
            del extack['msg']
        msg += os.strerror(self.error)
        if extack:
            msg += ' ' + str(extack)
        return msg
140
141
class ConfigError(Exception):
    """Raised when a YnlFamily is constructed with an unusable setting,
    for example a receive buffer smaller than the supported minimum."""
144
145
class NlPolicy:
    """Kernel policy for one mode (do or dump) of one operation.

    Returned by YnlFamily.get_policy(). Attributes of the policy
    are accessible as attributes of the object. Nested policies
    can be accessed indexing the object like a dictionary::

        pol = ynl.get_policy('page-pool-stats-get', 'do')
        pol['info'].type            # 'nested'
        pol['info']['id'].type      # 'uint'
        pol['info']['id'].min_value # 1

    Each policy entry always has a 'type' attribute (e.g. u32, string,
    nested). Optional attributes depending on the 'type': min-value,
    max-value, min-length, max-length, mask.

    Policies can form infinite nesting loops. These loops are trimmed
    when policy is converted to a dict with pol.to_dict().
    """
    def __init__(self, ynl, policy_idx, policy_table, attr_set, props=None):
        self._policy_idx = policy_idx
        self._policy_table = policy_table
        self._ynl = ynl
        self._props = props or {}
        # attribute name -> (spec or None, decoded policy dict)
        self._entries = {}
        # attribute name -> child NlPolicy, built lazily by __getitem__
        self._cache = {}
        if policy_idx is not None and policy_idx in policy_table:
            for attr_id, decoded in policy_table[policy_idx].items():
                if attr_set and attr_id in attr_set.attrs_by_val:
                    spec = attr_set.attrs_by_val[attr_id]
                    name = spec['name']
                else:
                    # No spec for this id: fall back to a synthetic name.
                    spec = None
                    name = f'attr-{attr_id}'
                self._entries[name] = (spec, decoded)

    def __getitem__(self, name):
        """Descend into a nested policy by attribute name."""
        if name not in self._cache:
            spec, decoded = self._entries[name]
            # Copy so popping 'policy-idx' does not modify the shared table.
            props = dict(decoded)
            child_idx = None
            child_set = None
            if 'policy-idx' in props:
                child_idx = props.pop('policy-idx')
                if spec and 'nested-attributes' in spec.yaml:
                    child_set = self._ynl.attr_sets[spec.yaml['nested-attributes']]
            self._cache[name] = NlPolicy(self._ynl, child_idx,
                                         self._policy_table,
                                         child_set, props)
        return self._cache[name]

    def __getattr__(self, name):
        """Access this policy entry's own properties (type, min-value, etc.).

        Underscores in the name are converted to dashes, so that
        pol.min_value looks up "min-value".

        Names starting with an underscore are never policy properties and
        are rejected immediately. Without this, a lookup on an instance
        whose _props is not set yet (copy/pickle create the object without
        __init__ and then probe for __setstate__) would re-enter
        __getattr__('_props') forever and end in RecursionError.
        """
        if name.startswith('_'):
            raise AttributeError(name)
        key = name.replace('_', '-')
        # Hack for level-0 which we still want to have .type but we don't
        # want type to pointlessly show up in the dict / JSON form.
        if not self._props and name == "type":
            return "nested"
        try:
            return self._props[key]
        except KeyError:
            raise AttributeError(name) from None

    def get(self, name, default=None):
        """Look up a child policy entry by attribute name, with a default."""
        try:
            return self[name]
        except KeyError:
            return default

    def __contains__(self, name):
        return name in self._entries

    def __len__(self):
        return len(self._entries)

    def __iter__(self):
        return iter(self._entries)

    def keys(self):
        """Return attribute names accepted by this policy."""
        return self._entries.keys()

    def to_dict(self, seen=None):
        """Convert to a plain dict, suitable for JSON serialization.

        Nested NlPolicy objects are expanded recursively. Cyclic
        references are trimmed (resolved to just {"type": "nested"}).
        """
        if seen is None:
            seen = set()
        result = dict(self._props)
        if self._policy_idx is not None:
            if self._policy_idx not in seen:
                # New set per branch, so siblings may share a sub-policy.
                seen = seen | {self._policy_idx}
                children = {}
                for name in self:
                    children[name] = self[name].to_dict(seen)
                if self._props:
                    result['policy'] = children
                else:
                    result = children
        return result

    def __repr__(self):
        return repr(self.to_dict())
257
258
class NlAttr:
    """One netlink attribute (TLV) located at *offset* inside a buffer.

    Attributes:
        type: attribute type with the NLA_F_NESTED / NLA_F_NET_BYTEORDER
            flag bits cleared.
        is_nest: non-zero when the NLA_F_NESTED bit was set.
        payload_len: the header's length field, which includes the 4-byte
            header itself.
        full_len: payload_len rounded up to a multiple of 4, i.e. the
            distance to the next attribute.
        raw: the payload bytes (header and padding excluded).
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    # Precompiled struct formats per scalar type, in native, big-endian
    # and little-endian byte order.
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for attr_type; byte_order 'big-endian' selects
        big endian, any other non-empty value little endian, None native."""
        format_ = cls.type_formats[attr_type]
        if byte_order:
            return format_.big if byte_order == "big-endian" \
                else format_.little
        return format_.native

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as a fixed-size scalar of attr_type."""
        format_ = self.get_format(attr_type, byte_order)
        return format_.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable-width scalar: 4 or 8 bytes, signedness taken
        from the first letter of attr_type ('s' or 'u')."""
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise YnlException(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        format_ = self.get_format(real_type, byte_order)
        return format_.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as an ASCII string without its NUL terminator."""
        text = self.raw.decode('ascii')
        # Payloads are expected to be NUL-terminated; strip the terminator
        # only when it is really there instead of chopping a real character.
        if text.endswith('\0'):
            return text[:-1]
        return text

    def as_bin(self):
        return self.raw

    def as_c_array(self, c_type):
        """Decode the payload as a packed native-endian array of c_type."""
        format_ = self.get_format(c_type)
        return [ x[0] for x in format_.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
311
312
class NlAttrs:
    """Sequence of netlink attributes parsed from msg[offset:]."""
    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            # The length field must at least cover the 4-byte header. A zero
            # length would never advance the offset and loop forever on
            # malformed input, so reject it explicitly.
            if attr.payload_len < 4:
                raise YnlException(
                    f"Malformed netlink attribute at offset {offset}: "
                    f"length {attr.payload_len} is shorter than its header")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)
332
333
class NlMsg:
    """One netlink message (16-byte header plus payload) at *offset* in a buffer.

    For NLMSG_ERROR and NLMSG_DONE the leading s32 status is stored in
    ``error`` and ``done`` is set. If the message carries NLM_F_ACK_TLVS,
    the extended-ack attributes are decoded into the ``extack`` dict and,
    when *attr_space* is given, annotated with attribute names.
    """
    def __init__(self, msg, offset, attr_space=None):
        payload_start = offset + 16
        self.hdr = msg[offset:payload_start]
        (self.nl_len, self.nl_type, self.nl_flags,
         self.nl_seq, self.nl_portid) = struct.unpack("IHHII", self.hdr)
        self.raw = msg[payload_start:offset + self.nl_len]

        self.error = 0
        self.done = 0

        # nlmsg_type -> offset of extack TLVs inside the payload. ERROR has
        # the s32 error plus a 16-byte echoed header (presumably capped
        # because sockets set NETLINK_CAP_ACK); DONE has just the s32.
        extack_offsets = {
            Netlink.NLMSG_ERROR: 20,
            Netlink.NLMSG_DONE: 4,
        }
        extack_off = extack_offsets.get(self.nl_type)
        if extack_off is not None:
            # Both terminal message types start with an s32 status.
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1

        self.extack = None
        if extack_off and self.nl_flags & Netlink.NLM_F_ACK_TLVS:
            self.extack = self._parse_extack(self.raw[extack_off:])
            if attr_space:
                self.annotate_extack(attr_space)

    @staticmethod
    def _parse_extack(raw):
        """Decode extended-ack TLVs into a dict keyed by readable names."""
        # NLMSGERR_ATTR_* types whose payload is a plain u32 -> dict key.
        u32_keys = {
            Netlink.NLMSGERR_ATTR_MISS_TYPE: 'miss-type',
            Netlink.NLMSGERR_ATTR_MISS_NEST: 'miss-nest',
            Netlink.NLMSGERR_ATTR_OFFS: 'bad-attr-offs',
        }
        extack = {}
        for tlv in NlAttrs(raw):
            if tlv.type == Netlink.NLMSGERR_ATTR_MSG:
                extack['msg'] = tlv.as_strz()
            elif tlv.type in u32_keys:
                extack[u32_keys[tlv.type]] = tlv.as_scalar('u32')
            elif tlv.type == Netlink.NLMSGERR_ATTR_POLICY:
                extack['policy'] = _genl_decode_policy(tlv.raw)
            else:
                extack.setdefault('unknown', []).append(tlv)
        return extack

    def annotate_extack(self, attr_space):
        """ Make extack more human friendly with attribute information """

        # Nested attributes cannot be resolved yet; only handle a missing
        # top-level attribute (miss-type without miss-nest).
        if 'miss-type' not in self.extack or 'miss-nest' in self.extack:
            return
        miss_type = self.extack['miss-type']
        if miss_type not in attr_space.attrs_by_val:
            return
        spec = attr_space.attrs_by_val[miss_type]
        self.extack['miss-type'] = spec['name']
        if 'doc' in spec:
            self.extack['miss-type-doc'] = spec['doc']

    def cmd(self):
        """Return the message type, which doubles as the command for raw netlink."""
        return self.nl_type

    def __repr__(self):
        parts = [f"nl_len = {self.nl_len} ({len(self.raw)}) "
                 f"nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"]
        if self.error:
            parts.append('error: ' + str(self.error))
        if self.extack:
            parts.append('extack: ' + repr(self.extack))
        return '\n\t'.join(parts)
402
403
404# pylint: disable=too-few-public-methods
class NlMsgs:
    """All netlink messages contained in one receive buffer, in order."""
    def __init__(self, data):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset)
            # nlmsg_len must at least cover the 16-byte header; in particular
            # a zero length would never advance the offset (infinite loop).
            if msg.nl_len < 16:
                raise YnlException(
                    f"Malformed netlink message at offset {offset}: "
                    f"nlmsg_len {msg.nl_len} is shorter than its header")
            # Consecutive messages start on 4-byte boundaries (NLMSG_ALIGN),
            # the same rounding NlAttr applies to attributes.
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
417
418
419def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
420    # we prepend length in _genl_msg_finalize()
421    if seq is None:
422        seq = random.randint(1, 1024)
423    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
424    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
425    return nlmsg + genlmsg
426
427
428def _genl_msg_finalize(msg):
429    return struct.pack("I", len(msg) + 4) + msg
430
431
def _genl_decode_policy(raw):
    """Decode one kernel policy entry (NL_POLICY_TYPE_ATTR_* TLVs) into a dict.

    Keys produced: 'type' (name from Netlink.AttrType), 'min-value',
    'max-value', 'min-length', 'max-length', 'policy-idx',
    'bitfield32-mask', 'mask'. Other policy attributes (e.g. pad,
    policy-maxtype) are ignored.

    A numeric type this library does not know (e.g. added by a newer
    kernel) is reported as 'unknown-<n>' instead of raising ValueError.
    That matters because this also runs while parsing extack in error
    replies, where an exception would hide the real netlink error.
    """
    # NL_POLICY_TYPE_ATTR_* -> (policy dict key, scalar type). Signed and
    # unsigned bounds intentionally share the 'min-value'/'max-value' keys.
    scalars = {
        Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S: ('min-value', 's64'),
        Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S: ('max-value', 's64'),
        Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U: ('min-value', 'u64'),
        Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U: ('max-value', 'u64'),
        Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH: ('min-length', 'u32'),
        Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH: ('max-length', 'u32'),
        Netlink.NL_POLICY_TYPE_ATTR_POLICY_IDX: ('policy-idx', 'u32'),
        Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK: ('bitfield32-mask', 'u32'),
        Netlink.NL_POLICY_TYPE_ATTR_MASK: ('mask', 'u64'),
    }
    policy = {}
    for attr in NlAttrs(raw):
        if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
            type_ = attr.as_scalar('u32')
            try:
                policy['type'] = Netlink.AttrType(type_).name
            except ValueError:
                policy['type'] = f'unknown-{type_}'
        elif attr.type in scalars:
            key, scalar_type = scalars[attr.type]
            policy[key] = attr.as_scalar(scalar_type)
    return policy
457
458
459# pylint: disable=too-many-nested-blocks
def _genl_parse_mcast_groups(raw):
    """Decode a CTRL_ATTR_MCAST_GROUPS nest into {group name: group id}.

    Entries lacking a name or an id are skipped.
    """
    groups = {}
    for entry in NlAttrs(raw):
        grp_name = None
        grp_id = None
        for grp_attr in NlAttrs(entry.raw):
            if grp_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                grp_name = grp_attr.as_strz()
            elif grp_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                grp_id = grp_attr.as_scalar('u32')
        if grp_name and grp_id is not None:
            groups[grp_name] = grp_id
    return groups


def _genl_parse_family(raw):
    """Decode one CTRL_CMD_GETFAMILY reply payload into a family dict.

    Possible keys: 'id', 'name', 'maxattr', 'mcast'.
    """
    fam = {}
    for attr in NlAttrs(raw):
        if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
            fam['id'] = attr.as_scalar('u16')
        elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
            fam['name'] = attr.as_strz()
        elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
            fam['maxattr'] = attr.as_scalar('u32')
        elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
            fam['mcast'] = _genl_parse_mcast_groups(attr.raw)
    return fam


def _genl_load_families():
    """Dump every generic netlink family known to the kernel controller.

    Returns a dict mapping family name to the family dict built by
    _genl_parse_family(); families without both a name and an id are
    skipped. Raises YnlException if any reply carries a netlink error.
    """
    families = {}

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL,
                      Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                      Netlink.CTRL_CMD_GETFAMILY, 1))
        sock.send(request, 0)

        # Keep reading until the dump's terminating (done) message arrives.
        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                if nl_msg.error:
                    raise YnlException(f"Netlink error: {nl_msg.error}")
                if nl_msg.done:
                    return families

                fam = _genl_parse_family(GenlMsg(nl_msg).raw)
                if 'name' in fam and 'id' in fam:
                    families[fam['name']] = fam
505
506
507# pylint: disable=too-many-nested-blocks
def _genl_policy_dump(family_id, op):
    """Dump the kernel validation policy of operation *op* in family *family_id*.

    Returns a tuple (op_policy, policy_table):
      op_policy    -- {'do': idx, 'dump': idx} for the modes the op supports
      policy_table -- {policy idx: {attr id: decoded policy dict}}
    Raises YnlException if a reply carries a netlink error.
    """
    op_policy = {}
    policy_table = {}
    # CTRL_ATTR_POLICY_* attribute type -> op_policy key
    modes = {
        Netlink.CTRL_ATTR_POLICY_DO: 'do',
        Netlink.CTRL_ATTR_POLICY_DUMP: 'dump',
    }

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        request = _genl_msg(Netlink.GENL_ID_CTRL,
                            Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                            Netlink.CTRL_CMD_GETPOLICY, 1)
        # Family id: 6-byte attribute (u16 payload) padded to 8 bytes.
        request += struct.pack('HHHxx', 6, Netlink.CTRL_ATTR_FAMILY_ID, family_id)
        # Operation: 8-byte attribute with a u32 payload.
        request += struct.pack('HHI', 8, Netlink.CTRL_ATTR_OP, op)
        sock.send(_genl_msg_finalize(request), 0)

        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                if nl_msg.error:
                    raise YnlException(f"Netlink error: {nl_msg.error}")
                if nl_msg.done:
                    return op_policy, policy_table

                for attr in NlAttrs(GenlMsg(nl_msg).raw):
                    if attr.type == Netlink.CTRL_ATTR_OP_POLICY:
                        for op_attr in NlAttrs(attr.raw):
                            for method_attr in NlAttrs(op_attr.raw):
                                mode = modes.get(method_attr.type)
                                if mode:
                                    op_policy[mode] = method_attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_POLICY:
                        # Nest: policy idx (as attr type) -> attr id (as attr type) -> policy
                        for pidx_attr in NlAttrs(attr.raw):
                            for aid_attr in NlAttrs(pidx_attr.raw):
                                entry = _genl_decode_policy(aid_attr.raw)
                                policy_table.setdefault(pidx_attr.type, {})[aid_attr.type] = entry
551
552
class GenlMsg:
    """Generic netlink view of a message: splits off the 4-byte genlmsghdr.

    ``raw`` holds everything after the genlmsghdr (any fixed header and
    the attributes); ``raw_attrs`` is filled in later by the protocol's
    decode().
    """
    def __init__(self, nl_msg):
        self.nl = nl_msg
        cmd, version, _reserved = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.genl_cmd = cmd
        self.genl_version = version
        self.raw = nl_msg.raw[4:]
        self.raw_attrs = []

    def cmd(self):
        """Return the generic netlink command number."""
        return self.genl_cmd

    def __repr__(self):
        pieces = [repr(self.nl),
                  f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"]
        pieces.extend(f"\t\t{a!r}\n" for a in self.raw_attrs)
        return ''.join(pieces)
569
570
class NetlinkProtocol:
    """Classic (raw) netlink framing: a bare nlmsghdr, no genetlink header.

    The command doubles as nlmsg_type, and the family is selected by the
    socket protocol number (*proto_num*).
    """
    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack nlmsghdr type/flags/seq/pid (pid 0) without nlmsg_len.

        A random sequence number in [1, 1024] is used when *seq* is None.
        """
        if seq is None:
            seq = random.randint(1, 1024)
        return struct.pack("HHII", nl_type, nl_flags, seq, 0)

    def message(self, flags, command, _version, seq=None):
        # Raw netlink has no version field; the command is the message type.
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap *nl_msg*, resolve its op and parse its attributes.

        When *op* is None it is looked up from the reply's command. The
        attributes are parsed after the op's fixed header into
        msg.raw_attrs.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        msg.raw_attrs = NlAttrs(msg.raw, ynl.struct_size(op.fixed_header))
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the group id for *mcast_name* from the spec's groups."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise YnlException(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        # sizeof(struct nlmsghdr)
        return 16
603
604
class GenlProtocol(NetlinkProtocol):
    """Generic netlink framing: nlmsghdr addressed to the family id plus genlmsghdr."""
    # Family table (name -> dict from _genl_load_families()), shared by all
    # instances and loaded from the kernel on first use.
    genl_family_name_to_id = {}

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        families = GenlProtocol.genl_family_name_to_id
        if not families:
            families = _genl_load_families()
            GenlProtocol.genl_family_name_to_id = families

        # KeyError here means the running kernel does not know the family.
        self.genl_family = families[family_name]
        self.family_id = self.genl_family['id']

    def message(self, flags, command, version, seq=None):
        header = self._message(self.family_id, flags, seq)
        return header + struct.pack("BBH", command, version, 0)

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the kernel-assigned id of multicast group *mcast_name*.

        Ids come from the kernel's family table; *mcast_groups* (from the
        spec) is not consulted.
        """
        groups = self.genl_family['mcast']
        if mcast_name in groups:
            return groups[mcast_name]
        raise YnlException(f'Multicast group "{mcast_name}" not present in the family')

    def msghdr_size(self):
        # nlmsghdr plus the 4-byte genlmsghdr
        return super().msghdr_size() + 4
632
633
634# pylint: disable=too-few-public-methods
class SpaceAttrs:
    """Stack of (attribute-set spec, values) scopes, innermost first.

    lookup() searches from the innermost scope outwards, so an inner
    attribute set shadows names defined by enclosing ones.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer=None):
        self.scopes = [self.SpecValuesPair(attr_space, attrs)]
        if outer:
            self.scopes.extend(outer.scopes)

    def lookup(self, name):
        """Return the value of *name* from the nearest scope defining it.

        Raises YnlException if the defining scope has no value for it, or
        if no scope defines it at all.
        """
        for spec, values in self.scopes:
            if name not in spec:
                continue
            if name in values:
                return values[name]
            spec_name = spec.yaml['name']
            raise YnlException(
                f"No value for '{name}' in attribute space '{spec_name}'")
        raise YnlException(f"Attribute '{name}' not defined in any attribute-set")
652
653
654#
655# YNL implementation details.
656#
657
658
659class YnlFamily(SpecFamily):
660    """
661    YNL family -- a Netlink interface built from a YAML spec.
662
663    Primary use of the class is to execute Netlink commands:
664
665      ynl.<op_name>(attrs, ...)
666
667    By default this will execute the <op_name> as "do", pass dump=True
668    to perform a dump operation.
669
670    ynl.<op_name> is a shorthand / convenience wrapper for the following
671    methods which take the op_name as a string:
672
673      ynl.do(op_name, attrs, flags=None) -- execute a do operation
674      ynl.dump(op_name, attrs)           -- execute a dump operation
675      ynl.do_multi(ops)                  -- batch multiple do operations
676
677    The flags argument in ynl.do() allows passing in extra NLM_F_* flags
678    which may be necessary for old families.
679
680    Notification API:
681
682      ynl.ntf_subscribe(mcast_name)      -- join a multicast group
683      ynl.check_ntf()                    -- drain pending notifications
684      ynl.poll_ntf(duration=None)        -- yield notifications
685
686    Policy introspection allows querying validation criteria from the running
687    kernel. Allows checking whether kernel supports a given attribute or value.
688
689      ynl.get_policy(op_name, mode)      -- query kernel policy for an op
690    """
691    def __init__(self, def_path, schema=None, process_unknown=False,
692                 recv_size=0):
693        super().__init__(def_path, schema)
694
695        self.include_raw = False
696        self.process_unknown = process_unknown
697
698        try:
699            if self.proto == "netlink-raw":
700                self.nlproto = NetlinkProtocol(self.yaml['name'],
701                                               self.yaml['protonum'])
702            else:
703                self.nlproto = GenlProtocol(self.yaml['name'])
704        except KeyError as err:
705            raise YnlException(f"Family '{self.yaml['name']}' not supported by the kernel") from err
706
707        self._recv_dbg = False
708        # Note that netlink will use conservative (min) message size for
709        # the first dump recv() on the socket, our setting will only matter
710        # from the second recv() on.
711        self._recv_size = recv_size if recv_size else 131072
712        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
713        # for a message, so smaller receive sizes will lead to truncation.
714        # Note that the min size for other families may be larger than 4k!
715        if self._recv_size < 4000:
716            raise ConfigError()
717
718        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
719        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
720        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
721        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
722
723        self.async_msg_ids = set()
724        self.async_msg_queue = queue.Queue()
725
726        for msg in self.msgs.values():
727            if msg.is_async:
728                self.async_msg_ids.add(msg.rsp_value)
729
730        for op_name, op in self.ops.items():
731            bound_f = functools.partial(self._op, op_name)
732            setattr(self, op.ident_name, bound_f)
733
734
735    def ntf_subscribe(self, mcast_name):
736        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
737        self.sock.bind((0, 0))
738        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
739                             mcast_id)
740
741    def set_recv_dbg(self, enabled):
742        self._recv_dbg = enabled
743
744    def _recv_dbg_print(self, reply, nl_msgs):
745        if not self._recv_dbg:
746            return
747        print("Recv: read", len(reply), "bytes,",
748              len(nl_msgs.msgs), "messages", file=sys.stderr)
749        for nl_msg in nl_msgs:
750            print("  ", nl_msg, file=sys.stderr)
751
752    def _encode_enum(self, attr_spec, value):
753        enum = self.consts[attr_spec['enum']]
754        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
755            scalar = 0
756            if isinstance(value, str):
757                value = [value]
758            for single_value in value:
759                scalar += enum.entries[single_value].user_value(as_flags = True)
760            return scalar
761        return enum.entries[value].user_value()
762
763    def _get_scalar(self, attr_spec, value):
764        try:
765            return int(value)
766        except (ValueError, TypeError) as e:
767            if 'enum' in attr_spec:
768                return self._encode_enum(attr_spec, value)
769            if attr_spec.display_hint:
770                return self._from_string(value, attr_spec)
771            raise e
772
773    # pylint: disable=too-many-statements
774    def _add_attr(self, space, name, value, search_attrs):
775        try:
776            attr = self.attr_sets[space][name]
777        except KeyError as err:
778            raise YnlException(f"Space '{space}' has no attribute '{name}'") from err
779        nl_type = attr.value
780
781        if attr.is_multi and isinstance(value, list):
782            attr_payload = b''
783            for subvalue in value:
784                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
785            return attr_payload
786
787        if attr["type"] == 'nest':
788            nl_type |= Netlink.NLA_F_NESTED
789            sub_space = attr['nested-attributes']
790            attr_payload = self._add_nest_attrs(value, sub_space, search_attrs)
791        elif attr['type'] == 'indexed-array' and attr['sub-type'] == 'nest':
792            nl_type |= Netlink.NLA_F_NESTED
793            sub_space = attr['nested-attributes']
794            attr_payload = self._encode_indexed_array(value, sub_space,
795                                                      search_attrs)
796        elif attr["type"] == 'flag':
797            if not value:
798                # If value is absent or false then skip attribute creation.
799                return b''
800            attr_payload = b''
801        elif attr["type"] == 'string':
802            attr_payload = str(value).encode('ascii') + b'\x00'
803        elif attr["type"] == 'binary':
804            if value is None:
805                attr_payload = b''
806            elif isinstance(value, bytes):
807                attr_payload = value
808            elif isinstance(value, str):
809                if attr.display_hint:
810                    attr_payload = self._from_string(value, attr)
811                else:
812                    attr_payload = bytes.fromhex(value)
813            elif isinstance(value, dict) and attr.struct_name:
814                attr_payload = self._encode_struct(attr.struct_name, value)
815            elif isinstance(value, list) and attr.sub_type in NlAttr.type_formats:
816                format_ = NlAttr.get_format(attr.sub_type)
817                attr_payload = b''.join([format_.pack(x) for x in value])
818            else:
819                raise YnlException(f'Unknown type for binary attribute, value: {value}')
820        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
821            scalar = self._get_scalar(attr, value)
822            if attr.is_auto_scalar:
823                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
824            else:
825                attr_type = attr["type"]
826            format_ = NlAttr.get_format(attr_type, attr.byte_order)
827            attr_payload = format_.pack(scalar)
828        elif attr['type'] in "bitfield32":
829            scalar_value = self._get_scalar(attr, value["value"])
830            scalar_selector = self._get_scalar(attr, value["selector"])
831            attr_payload = struct.pack("II", scalar_value, scalar_selector)
832        elif attr['type'] == 'sub-message':
833            msg_format, _ = self._resolve_selector(attr, search_attrs)
834            attr_payload = b''
835            if msg_format.fixed_header:
836                attr_payload += self._encode_struct(msg_format.fixed_header, value)
837            if msg_format.attr_set:
838                if msg_format.attr_set in self.attr_sets:
839                    nl_type |= Netlink.NLA_F_NESTED
840                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
841                    for subname, subvalue in value.items():
842                        attr_payload += self._add_attr(msg_format.attr_set,
843                                                       subname, subvalue, sub_attrs)
844                else:
845                    raise YnlException(f"Unknown attribute-set '{msg_format.attr_set}'")
846        else:
847            raise YnlException(f'Unknown type at {space} {name} {value} {attr["type"]}')
848
849        return self._add_attr_raw(nl_type, attr_payload)
850
851    def _add_attr_raw(self, nl_type, attr_payload):
852        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
853        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
854
855    def _add_nest_attrs(self, value, sub_space, search_attrs):
856        sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
857        attr_payload = b''
858        for subname, subvalue in value.items():
859            attr_payload += self._add_attr(sub_space, subname, subvalue,
860                                           sub_attrs)
861        return attr_payload
862
863    def _encode_indexed_array(self, vals, sub_space, search_attrs):
864        attr_payload = b''
865        for i, val in enumerate(vals):
866            idx = i | Netlink.NLA_F_NESTED
867            val_payload = self._add_nest_attrs(val, sub_space, search_attrs)
868            attr_payload += self._add_attr_raw(idx, val_payload)
869        return attr_payload
870
871    def _get_enum_or_unknown(self, enum, raw):
872        try:
873            name = enum.entries_by_val[raw].name
874        except KeyError as error:
875            if self.process_unknown:
876                name = f"Unknown({raw})"
877            else:
878                raise error
879        return name
880
881    def _decode_enum(self, raw, attr_spec):
882        enum = self.consts[attr_spec['enum']]
883        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
884            i = 0
885            value = set()
886            while raw:
887                if raw & 1:
888                    value.add(self._get_enum_or_unknown(enum, i))
889                raw >>= 1
890                i += 1
891        else:
892            value = self._get_enum_or_unknown(enum, raw)
893        return value
894
895    def _decode_binary(self, attr, attr_spec):
896        if attr_spec.struct_name:
897            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
898        elif attr_spec.sub_type:
899            decoded = attr.as_c_array(attr_spec.sub_type)
900            if 'enum' in attr_spec:
901                decoded = [ self._decode_enum(x, attr_spec) for x in decoded ]
902            elif attr_spec.display_hint:
903                decoded = [ self._formatted_string(x, attr_spec.display_hint)
904                            for x in decoded ]
905        else:
906            decoded = attr.as_bin()
907            if attr_spec.display_hint:
908                decoded = self._formatted_string(decoded, attr_spec.display_hint)
909        return decoded
910
911    def _decode_array_attr(self, attr, attr_spec):
912        decoded = []
913        offset = 0
914        while offset < len(attr.raw):
915            item = NlAttr(attr.raw, offset)
916            offset += item.full_len
917
918            if attr_spec["sub-type"] == 'nest':
919                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
920                decoded.append({ item.type: subattrs })
921            elif attr_spec["sub-type"] == 'binary':
922                subattr = item.as_bin()
923                if attr_spec.display_hint:
924                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
925                decoded.append(subattr)
926            elif attr_spec["sub-type"] in NlAttr.type_formats:
927                subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
928                if 'enum' in attr_spec:
929                    subattr = self._decode_enum(subattr, attr_spec)
930                elif attr_spec.display_hint:
931                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
932                decoded.append(subattr)
933            else:
934                raise YnlException(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
935        return decoded
936
937    def _decode_nest_type_value(self, attr, attr_spec):
938        decoded = {}
939        value = attr
940        for name in attr_spec['type-value']:
941            value = NlAttr(value.raw, 0)
942            decoded[name] = value.type
943        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
944        decoded.update(subattrs)
945        return decoded
946
947    def _decode_unknown(self, attr):
948        if attr.is_nest:
949            return self._decode(NlAttrs(attr.raw), None)
950        return attr.as_bin()
951
952    def _rsp_add(self, rsp, name, is_multi, decoded):
953        if is_multi is None:
954            if name in rsp and not isinstance(rsp[name], list):
955                rsp[name] = [rsp[name]]
956                is_multi = True
957            else:
958                is_multi = False
959
960        if not is_multi:
961            rsp[name] = decoded
962        elif name in rsp:
963            rsp[name].append(decoded)
964        else:
965            rsp[name] = [decoded]
966
967    def _resolve_selector(self, attr_spec, search_attrs):
968        sub_msg = attr_spec.sub_message
969        if sub_msg not in self.sub_msgs:
970            raise YnlException(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
971        sub_msg_spec = self.sub_msgs[sub_msg]
972
973        selector = attr_spec.selector
974        value = search_attrs.lookup(selector)
975        if value not in sub_msg_spec.formats:
976            raise YnlException(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
977
978        spec = sub_msg_spec.formats[value]
979        return spec, value
980
981    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
982        msg_format, _ = self._resolve_selector(attr_spec, search_attrs)
983        decoded = {}
984        offset = 0
985        if msg_format.fixed_header:
986            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
987            offset = self.struct_size(msg_format.fixed_header)
988        if msg_format.attr_set:
989            if msg_format.attr_set in self.attr_sets:
990                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
991                decoded.update(subdict)
992            else:
993                raise YnlException(f"Unknown attribute-set '{msg_format.attr_set}' "
994                                   f"when decoding '{attr_spec.name}'")
995        return decoded
996
997    # pylint: disable=too-many-statements
998    def _decode(self, attrs, space, outer_attrs = None):
999        rsp = {}
1000        search_attrs = {}
1001        if space:
1002            attr_space = self.attr_sets[space]
1003            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
1004
1005        for attr in attrs:
1006            try:
1007                attr_spec = attr_space.attrs_by_val[attr.type]
1008            except (KeyError, UnboundLocalError) as err:
1009                if not self.process_unknown:
1010                    raise YnlException(f"Space '{space}' has no attribute "
1011                                       f"with value '{attr.type}'") from err
1012                attr_name = f"UnknownAttr({attr.type})"
1013                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
1014                continue
1015
1016            try:
1017                if attr_spec["type"] == 'pad':
1018                    continue
1019                elif attr_spec["type"] == 'nest':
1020                    subdict = self._decode(NlAttrs(attr.raw),
1021                                           attr_spec['nested-attributes'],
1022                                           search_attrs)
1023                    decoded = subdict
1024                elif attr_spec["type"] == 'string':
1025                    decoded = attr.as_strz()
1026                elif attr_spec["type"] == 'binary':
1027                    decoded = self._decode_binary(attr, attr_spec)
1028                elif attr_spec["type"] == 'flag':
1029                    decoded = True
1030                elif attr_spec.is_auto_scalar:
1031                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
1032                    if 'enum' in attr_spec:
1033                        decoded = self._decode_enum(decoded, attr_spec)
1034                elif attr_spec["type"] in NlAttr.type_formats:
1035                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
1036                    if 'enum' in attr_spec:
1037                        decoded = self._decode_enum(decoded, attr_spec)
1038                    elif attr_spec.display_hint:
1039                        decoded = self._formatted_string(decoded, attr_spec.display_hint)
1040                elif attr_spec["type"] == 'indexed-array':
1041                    decoded = self._decode_array_attr(attr, attr_spec)
1042                elif attr_spec["type"] == 'bitfield32':
1043                    value, selector = struct.unpack("II", attr.raw)
1044                    if 'enum' in attr_spec:
1045                        value = self._decode_enum(value, attr_spec)
1046                        selector = self._decode_enum(selector, attr_spec)
1047                    decoded = {"value": value, "selector": selector}
1048                elif attr_spec["type"] == 'sub-message':
1049                    decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
1050                elif attr_spec["type"] == 'nest-type-value':
1051                    decoded = self._decode_nest_type_value(attr, attr_spec)
1052                else:
1053                    if not self.process_unknown:
1054                        raise YnlException(f'Unknown {attr_spec["type"]} '
1055                                           f'with name {attr_spec["name"]}')
1056                    decoded = self._decode_unknown(attr)
1057
1058                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
1059            except:
1060                print(f"Error decoding '{attr_spec.name}' from '{space}'")
1061                raise
1062
1063        return rsp
1064
1065    # pylint: disable=too-many-arguments, too-many-positional-arguments
1066    def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs):
1067        for attr in attrs:
1068            try:
1069                attr_spec = attr_set.attrs_by_val[attr.type]
1070            except KeyError as err:
1071                raise YnlException(
1072                    f"Space '{attr_set.name}' has no attribute with value '{attr.type}'") from err
1073            if offset > target:
1074                break
1075            if offset == target:
1076                return '.' + attr_spec.name
1077
1078            if offset + attr.full_len <= target:
1079                offset += attr.full_len
1080                continue
1081
1082            pathname = attr_spec.name
1083            if attr_spec['type'] == 'nest':
1084                sub_attrs = self.attr_sets[attr_spec['nested-attributes']]
1085                search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name']))
1086            elif attr_spec['type'] == 'sub-message':
1087                msg_format, value = self._resolve_selector(attr_spec, search_attrs)
1088                if msg_format is None:
1089                    raise YnlException(f"Can't resolve sub-message of "
1090                                       f"{attr_spec['name']} for extack")
1091                sub_attrs = self.attr_sets[msg_format.attr_set]
1092                pathname += f"({value})"
1093            else:
1094                raise YnlException(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
1095            offset += 4
1096            subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs,
1097                                               offset, target, search_attrs)
1098            if subpath is None:
1099                return None
1100            return '.' + pathname + subpath
1101
1102        return None
1103
1104    def _decode_extack(self, request, op, extack, vals):
1105        if 'bad-attr-offs' not in extack:
1106            return
1107
1108        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
1109        offset = self.nlproto.msghdr_size() + self.struct_size(op.fixed_header)
1110        search_attrs = SpaceAttrs(op.attr_set, vals)
1111        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
1112                                        extack['bad-attr-offs'], search_attrs)
1113        if path:
1114            del extack['bad-attr-offs']
1115            extack['bad-attr'] = path
1116
1117    def struct_size(self, name):
1118        if name:
1119            members = self.consts[name].members
1120            size = 0
1121            for m in members:
1122                if m.type in ['pad', 'binary']:
1123                    if m.struct:
1124                        size += self.struct_size(m.struct)
1125                    else:
1126                        size += m.len
1127                else:
1128                    format_ = NlAttr.get_format(m.type, m.byte_order)
1129                    size += format_.size
1130            return size
1131        return 0
1132
1133    def _decode_struct(self, data, name):
1134        members = self.consts[name].members
1135        attrs = {}
1136        offset = 0
1137        for m in members:
1138            value = None
1139            if m.type == 'pad':
1140                offset += m.len
1141            elif m.type == 'binary':
1142                if m.struct:
1143                    len_ = self.struct_size(m.struct)
1144                    value = self._decode_struct(data[offset : offset + len_],
1145                                                m.struct)
1146                    offset += len_
1147                else:
1148                    value = data[offset : offset + m.len]
1149                    offset += m.len
1150            else:
1151                format_ = NlAttr.get_format(m.type, m.byte_order)
1152                [ value ] = format_.unpack_from(data, offset)
1153                offset += format_.size
1154            if value is not None:
1155                if m.enum:
1156                    value = self._decode_enum(value, m)
1157                elif m.display_hint:
1158                    value = self._formatted_string(value, m.display_hint)
1159                attrs[m.name] = value
1160        return attrs
1161
1162    def _encode_struct(self, name, vals):
1163        members = self.consts[name].members
1164        attr_payload = b''
1165        for m in members:
1166            value = vals.pop(m.name) if m.name in vals else None
1167            if m.type == 'pad':
1168                attr_payload += bytearray(m.len)
1169            elif m.type == 'binary':
1170                if m.struct:
1171                    if value is None:
1172                        value = {}
1173                    attr_payload += self._encode_struct(m.struct, value)
1174                else:
1175                    if value is None:
1176                        attr_payload += bytearray(m.len)
1177                    else:
1178                        attr_payload += bytes.fromhex(value)
1179            else:
1180                if value is None:
1181                    value = 0
1182                format_ = NlAttr.get_format(m.type, m.byte_order)
1183                attr_payload += format_.pack(value)
1184        return attr_payload
1185
1186    def _formatted_string(self, raw, display_hint):
1187        if display_hint == 'mac':
1188            formatted = ':'.join(f'{b:02x}' for b in raw)
1189        elif display_hint == 'hex':
1190            if isinstance(raw, int):
1191                formatted = hex(raw)
1192            else:
1193                formatted = bytes.hex(raw, ' ')
1194        elif display_hint in [ 'ipv4', 'ipv6', 'ipv4-or-v6' ]:
1195            formatted = format(ipaddress.ip_address(raw))
1196        elif display_hint == 'uuid':
1197            formatted = str(uuid.UUID(bytes=raw))
1198        else:
1199            formatted = raw
1200        return formatted
1201
1202    def _from_string(self, string, attr_spec):
1203        if attr_spec.display_hint in ['ipv4', 'ipv6', 'ipv4-or-v6']:
1204            ip = ipaddress.ip_address(string)
1205            if attr_spec['type'] == 'binary':
1206                raw = ip.packed
1207            else:
1208                raw = int(ip)
1209        elif attr_spec.display_hint == 'hex':
1210            if attr_spec['type'] == 'binary':
1211                raw = bytes.fromhex(string)
1212            else:
1213                raw = int(string, 16)
1214        elif attr_spec.display_hint == 'mac':
1215            # Parse MAC address in format "00:11:22:33:44:55" or "001122334455"
1216            if ':' in string:
1217                mac_bytes = [int(x, 16) for x in string.split(':')]
1218            else:
1219                if len(string) % 2 != 0:
1220                    raise YnlException(f"Invalid MAC address format: {string}")
1221                mac_bytes = [int(string[i:i+2], 16) for i in range(0, len(string), 2)]
1222            raw = bytes(mac_bytes)
1223        else:
1224            raise YnlException(f"Display hint '{attr_spec.display_hint}' not implemented"
1225                            f" when parsing '{attr_spec['name']}'")
1226        return raw
1227
1228    def handle_ntf(self, decoded):
1229        msg = {}
1230        if self.include_raw:
1231            msg['raw'] = decoded
1232        op = self.rsp_by_value[decoded.cmd()]
1233        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
1234        if op.fixed_header:
1235            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
1236
1237        msg['name'] = op['name']
1238        msg['msg'] = attrs
1239        self.async_msg_queue.put(msg)
1240
1241    def check_ntf(self):
1242        while True:
1243            try:
1244                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
1245            except BlockingIOError:
1246                return
1247
1248            nms = NlMsgs(reply)
1249            self._recv_dbg_print(reply, nms)
1250            for nl_msg in nms:
1251                if nl_msg.error:
1252                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
1253                    print(nl_msg)
1254                    continue
1255                if nl_msg.done:
1256                    print("Netlink done while checking for ntf!?")
1257                    continue
1258
1259                decoded = self.nlproto.decode(self, nl_msg, None)
1260                if decoded.cmd() not in self.async_msg_ids:
1261                    print("Unexpected msg id while checking for ntf", decoded)
1262                    continue
1263
1264                self.handle_ntf(decoded)
1265
1266    def poll_ntf(self, duration=None):
1267        start_time = time.time()
1268        selector = selectors.DefaultSelector()
1269        selector.register(self.sock, selectors.EVENT_READ)
1270
1271        while True:
1272            try:
1273                yield self.async_msg_queue.get_nowait()
1274            except queue.Empty:
1275                if duration is not None:
1276                    timeout = start_time + duration - time.time()
1277                    if timeout <= 0:
1278                        return
1279                else:
1280                    timeout = None
1281                events = selector.select(timeout)
1282                if events:
1283                    self.check_ntf()
1284
1285    def operation_do_attributes(self, name):
1286        """
1287        For a given operation name, find and return a supported
1288        set of attributes (as a dict).
1289        """
1290        op = self.find_operation(name)
1291        if not op:
1292            return None
1293
1294        return op['do']['request']['attributes'].copy()
1295
1296    def _encode_message(self, op, vals, flags, req_seq):
1297        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
1298        for flag in flags or []:
1299            nl_flags |= flag
1300
1301        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
1302        if op.fixed_header:
1303            msg += self._encode_struct(op.fixed_header, vals)
1304        search_attrs = SpaceAttrs(op.attr_set, vals)
1305        for name, value in vals.items():
1306            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
1307        msg = _genl_msg_finalize(msg)
1308        return msg
1309
1310    # pylint: disable=too-many-statements
1311    def _ops(self, ops):
1312        reqs_by_seq = {}
1313        req_seq = random.randint(1024, 65535)
1314        payload = b''
1315        for (method, vals, flags) in ops:
1316            op = self.ops[method]
1317            msg = self._encode_message(op, vals, flags, req_seq)
1318            reqs_by_seq[req_seq] = (op, vals, msg, flags)
1319            payload += msg
1320            req_seq += 1
1321
1322        self.sock.send(payload, 0)
1323
1324        done = False
1325        rsp = []
1326        op_rsp = []
1327        while not done:
1328            reply = self.sock.recv(self._recv_size)
1329            nms = NlMsgs(reply)
1330            self._recv_dbg_print(reply, nms)
1331            for nl_msg in nms:
1332                if nl_msg.nl_seq in reqs_by_seq:
1333                    (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
1334                    if nl_msg.extack:
1335                        nl_msg.annotate_extack(op.attr_set)
1336                        self._decode_extack(req_msg, op, nl_msg.extack, vals)
1337                else:
1338                    op = None
1339                    req_flags = []
1340
1341                if nl_msg.error:
1342                    raise NlError(nl_msg)
1343                if nl_msg.done:
1344                    if nl_msg.extack:
1345                        print("Netlink warning:")
1346                        print(nl_msg)
1347
1348                    if Netlink.NLM_F_DUMP in req_flags:
1349                        rsp.append(op_rsp)
1350                    elif not op_rsp:
1351                        rsp.append(None)
1352                    elif len(op_rsp) == 1:
1353                        rsp.append(op_rsp[0])
1354                    else:
1355                        rsp.append(op_rsp)
1356                    op_rsp = []
1357
1358                    del reqs_by_seq[nl_msg.nl_seq]
1359                    done = len(reqs_by_seq) == 0
1360                    break
1361
1362                decoded = self.nlproto.decode(self, nl_msg, op)
1363
1364                # Check if this is a reply to our request
1365                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
1366                    if decoded.cmd() in self.async_msg_ids:
1367                        self.handle_ntf(decoded)
1368                        continue
1369                    print('Unexpected message: ' + repr(decoded))
1370                    continue
1371
1372                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
1373                if op.fixed_header:
1374                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
1375                op_rsp.append(rsp_msg)
1376
1377        return rsp
1378
1379    def _op(self, method, vals, flags=None, dump=False):
1380        req_flags = flags or []
1381        if dump:
1382            req_flags.append(Netlink.NLM_F_DUMP)
1383
1384        ops = [(method, vals, req_flags)]
1385        return self._ops(ops)[0]
1386
1387    def do(self, method, vals, flags=None):
1388        return self._op(method, vals, flags)
1389
1390    def dump(self, method, vals):
1391        return self._op(method, vals, dump=True)
1392
1393    def do_multi(self, ops):
1394        return self._ops(ops)
1395
1396    def get_policy(self, op_name, mode):
1397        """Query running kernel for the Netlink policy of an operation.
1398
1399        Allows checking whether kernel supports a given attribute or value.
1400        This method consults the running kernel, not the YAML spec.
1401
1402        Args:
1403            op_name: operation name as it appears in the YAML spec
1404            mode: 'do' or 'dump'
1405
1406        Returns:
1407            NlPolicy acting as a read-only dict mapping attribute names
1408            to their policy properties (type, min/max, nested, etc.),
1409            or None if the operation has no policy for the given mode.
1410            Empty policy usually implies that the operation rejects
1411            all attributes.
1412        """
1413        op = self.ops[op_name]
1414        op_policy, policy_table = _genl_policy_dump(self.nlproto.family_id,
1415                                                    op.req_value)
1416        if mode not in op_policy:
1417            return None
1418        policy_idx = op_policy[mode]
1419        return NlPolicy(self, policy_idx, policy_table, op.attr_set)
1420