xref: /linux/tools/net/ynl/pyynl/lib/ynl.py (revision 68993ced0f618e36cf33388f1e50223e5e6e78cc)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2#
3# pylint: disable=missing-class-docstring, missing-function-docstring
4# pylint: disable=too-many-branches, too-many-locals, too-many-instance-attributes
5# pylint: disable=too-many-lines
6
7"""
8YAML Netlink Library
9
10An implementation of the genetlink and raw netlink protocols.
11"""
12
13from collections import namedtuple
14from enum import Enum
15import functools
16import os
17import random
18import socket
19import struct
20from struct import Struct
21import sys
22import ipaddress
23import uuid
24import queue
25import selectors
26import time
27
28from .nlspec import SpecFamily
29
30#
31# Generic Netlink code which should really be in some library, but I can't quickly find one.
32#
33
34
class YnlException(Exception):
    """Generic error raised by the YNL library.

    Used throughout this module for failures reported while talking to the
    kernel and for requests that do not match the loaded spec.
    """
37
38
39# pylint: disable=too-few-public-methods
class Netlink:
    """Namespace for the numeric netlink constants used by this module.

    The class is never instantiated; it only groups the values so that the
    rest of the file can refer to them as ``Netlink.<NAME>``.
    """
    # Netlink socket
    # Option level and option names passed to setsockopt() on the netlink
    # sockets opened in this file.
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_LISTEN_ALL_NSID = 8
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    # Message types which NlMsg treats as terminating a reply.
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # Header flag bits. Note that the three groups below reuse the same
    # numeric values (0x100, 0x200, ...).
    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # Flag bits stored in the upper part of an attribute's type field.
    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    # Despite the name this is the mask of the *flag* bits; NlAttr strips
    # them with "& ~NLA_TYPE_MASK" to obtain the attribute type.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    # Message type used for requests to the genetlink controller (nlctrl).
    GENL_ID_CTRL = 0x10

    # nlctrl
    # Commands
    CTRL_CMD_GETFAMILY = 3
    CTRL_CMD_GETPOLICY = 10

    # Top-level attributes
    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7
    CTRL_ATTR_POLICY = 8
    CTRL_ATTR_OP_POLICY = 9
    CTRL_ATTR_OP = 10

    # Attributes nested inside each CTRL_ATTR_MCAST_GROUPS entry
    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Attributes nested inside each CTRL_ATTR_OP_POLICY entry
    CTRL_ATTR_POLICY_DO = 1
    CTRL_ATTR_POLICY_DUMP = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    # Attributes describing one policy entry, see _genl_decode_policy().
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Members are numbered from 1 in list order (functional Enum API);
    # _genl_decode_policy() maps the numeric NL_POLICY_TYPE_ATTR_TYPE value
    # reported in a policy onto these names.
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                 's8', 's16', 's32', 's64',
                                 'binary', 'string', 'nul-string',
                                 'nested', 'nested-array',
                                 'bitfield32', 'sint', 'uint'])
124
class NlError(Exception):
    """Error reported by the kernel in reply to a netlink request.

    Attributes:
        nl_msg: the message object carrying the error; its ``error`` and
            ``extack`` attributes are read here.
        error: the error code with its sign flipped (``-nl_msg.error``),
            passed to os.strerror() when formatting.
    """
    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        # Work on a copy, 'msg' is removed from it so that it is not
        # printed twice.
        details = self.nl_msg.extack.copy() if self.nl_msg.extack else {}

        parts = ["Netlink error: "]
        if 'msg' in details:
            parts.append(details.pop('msg') + ': ')
        parts.append(os.strerror(self.error))
        if details:
            # Whatever extack information is left is shown in dict form.
            parts.append(' ' + str(details))
        return ''.join(parts)
141
142
class ConfigError(Exception):
    """Raised when the library is configured with unusable settings.

    YnlFamily raises it when the requested receive buffer size is too small.
    """
145
146
147class NlPolicy:
148    """Kernel policy for one mode (do or dump) of one operation.
149
150    Returned by YnlFamily.get_policy(). Attributes of the policy
151    are accessible as attributes of the object. Nested policies
152    can be accessed indexing the object like a dictionary::
153
154        pol = ynl.get_policy('page-pool-stats-get', 'do')
155        pol['info'].type            # 'nested'
156        pol['info']['id'].type      # 'uint'
157        pol['info']['id'].min_value # 1
158
159    Each policy entry always has a 'type' attribute (e.g. u32, string,
160    nested). Optional attributes depending on the 'type': min-value,
161    max-value, min-length, max-length, mask.
162
163    Policies can form infinite nesting loops. These loops are trimmed
164    when policy is converted to a dict with pol.to_dict().
165    """
166    def __init__(self, ynl, policy_idx, policy_table, attr_set, props=None):
167        self._policy_idx = policy_idx
168        self._policy_table = policy_table
169        self._ynl = ynl
170        self._props = props or {}
171        self._entries = {}
172        self._cache = {}
173        if policy_idx is not None and policy_idx in policy_table:
174            for attr_id, decoded in policy_table[policy_idx].items():
175                if attr_set and attr_id in attr_set.attrs_by_val:
176                    spec = attr_set.attrs_by_val[attr_id]
177                    name = spec['name']
178                else:
179                    spec = None
180                    name = f'attr-{attr_id}'
181                self._entries[name] = (spec, decoded)
182
183    def __getitem__(self, name):
184        """Descend into a nested policy by attribute name."""
185        if name not in self._cache:
186            spec, decoded = self._entries[name]
187            props = dict(decoded)
188            child_idx = None
189            child_set = None
190            if 'policy-idx' in props:
191                child_idx = props.pop('policy-idx')
192                if spec and 'nested-attributes' in spec.yaml:
193                    child_set = self._ynl.attr_sets[spec.yaml['nested-attributes']]
194            self._cache[name] = NlPolicy(self._ynl, child_idx,
195                                         self._policy_table,
196                                         child_set, props)
197        return self._cache[name]
198
199    def __getattr__(self, name):
200        """Access this policy entry's own properties (type, min-value, etc.).
201
202        Underscores in the name are converted to dashes, so that
203        pol.min_value looks up "min-value".
204        """
205        key = name.replace('_', '-')
206        try:
207            # Hack for level-0 which we still want to have .type but we don't
208            # want type to pointlessly show up in the dict / JSON form.
209            if not self._props and name == "type":
210                return "nested"
211            return self._props[key]
212        except KeyError:
213            raise AttributeError(name)
214
215    def get(self, name, default=None):
216        """Look up a child policy entry by attribute name, with a default."""
217        try:
218            return self[name]
219        except KeyError:
220            return default
221
222    def __contains__(self, name):
223        return name in self._entries
224
225    def __len__(self):
226        return len(self._entries)
227
228    def __iter__(self):
229        return iter(self._entries)
230
231    def keys(self):
232        """Return attribute names accepted by this policy."""
233        return self._entries.keys()
234
235    def to_dict(self, seen=None):
236        """Convert to a plain dict, suitable for JSON serialization.
237
238        Nested NlPolicy objects are expanded recursively. Cyclic
239        references are trimmed (resolved to just {"type": "nested"}).
240        """
241        if seen is None:
242            seen = set()
243        result = dict(self._props)
244        if self._policy_idx is not None:
245            if self._policy_idx not in seen:
246                seen = seen | {self._policy_idx}
247                children = {}
248                for name in self:
249                    children[name] = self[name].to_dict(seen)
250                if self._props:
251                    result['policy'] = children
252                else:
253                    result = children
254        return result
255
256    def __repr__(self):
257        return repr(self.to_dict())
258
259
class NlAttr:
    """A single netlink attribute parsed out of a byte buffer.

    Attributes:
        type        -- attribute type with the flag bits masked out
        is_nest     -- non-zero if the NLA_F_NESTED flag is set
        payload_len -- length from the attribute header; as the slicing in
                       __init__ shows it counts the 4 byte header as well
        full_len    -- payload_len rounded up to a multiple of 4, the
                       distance to the next attribute
        raw         -- the payload bytes (header stripped)
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    # Pre-compiled struct formats for each scalar type, in native, big and
    # little endian flavors.
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute which starts at offset within raw."""
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for attr_type in the requested byte order.

        byte_order "big-endian" selects the big endian format, any other
        non-empty value the little endian one, None / empty the native one.
        Raises KeyError if attr_type is not a known scalar type.
        """
        format_ = cls.type_formats[attr_type]
        if byte_order:
            return format_.big if byte_order == "big-endian" \
                else format_.little
        return format_.native

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as a single scalar of type attr_type."""
        format_ = self.get_format(attr_type, byte_order)
        return format_.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable width scalar ('uint' / 'sint').

        The width is taken from the payload length, the signedness from
        the first letter of attr_type.
        Raises YnlException if the payload is neither 4 nor 8 bytes long.
        """
        if len(self.raw) not in (4, 8):
            raise YnlException(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        # e.g. 'uint' with 8 bytes of payload -> 'u64'
        real_type = attr_type[0] + str(len(self.raw) * 8)
        format_ = self.get_format(real_type, byte_order)
        return format_.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as an ASCII string.

        One terminating NUL is removed if present. A payload which does
        not end with NUL is returned whole, rather than losing its last
        character.
        """
        text = self.raw.decode('ascii')
        return text[:-1] if text.endswith('\x00') else text

    def as_bin(self):
        """Return the payload bytes unchanged."""
        return self.raw

    def as_c_array(self, c_type):
        """Decode the payload as a list of native endian c_type scalars."""
        format_ = self.get_format(c_type)
        return [value for (value,) in format_.iter_unpack(self.raw)]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
312
313
class NlAttrs:
    """Sequence of attributes parsed from a byte buffer.

    Parsing starts at offset and continues to the end of msg, the
    resulting NlAttr objects are stored in self.attrs in buffer order.

    Raises YnlException if an attribute header carries a length smaller
    than the header itself.
    """
    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            # The length in the header covers the 4 byte header itself.
            # Anything smaller is malformed, and a length of 0 would keep
            # us spinning on the same offset forever.
            if attr.payload_len < 4:
                raise YnlException(
                    f"Malformed attribute at offset {offset}: length {attr.payload_len}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)
333
334
class NlMsg:
    """One netlink message parsed out of a receive buffer.

    The 16 byte header is unpacked into the nl_* attributes, the rest of
    the message is kept in raw. For NLMSG_ERROR and NLMSG_DONE messages
    error holds the code carried in the first 4 bytes of the payload and
    done is set to 1. If the message has the NLM_F_ACK_TLVS flag the
    extended ack attributes are decoded into the extack dict, otherwise
    extack is None.
    """
    def __init__(self, msg, offset, attr_space=None):
        payload_start = offset + 16
        self.hdr = msg[offset:payload_start]

        (self.nl_len, self.nl_type, self.nl_flags,
         self.nl_seq, self.nl_portid) = struct.unpack("IHHII", self.hdr)

        self.raw = msg[payload_start : offset + self.nl_len]

        self.error = 0
        self.done = 0
        self.extack = None

        # Where the extack attributes start within the payload, for the
        # message types which can carry them.
        extack_off = {
            Netlink.NLMSG_ERROR: 20,
            Netlink.NLMSG_DONE: 4,
        }.get(self.nl_type)
        if extack_off is None:
            return

        self.error = struct.unpack("i", self.raw[0:4])[0]
        self.done = 1

        if not self.nl_flags & Netlink.NLM_F_ACK_TLVS:
            return

        # Extack attributes which are plain u32 values -> key in the dict
        u32_fields = {
            Netlink.NLMSGERR_ATTR_MISS_TYPE: 'miss-type',
            Netlink.NLMSGERR_ATTR_MISS_NEST: 'miss-nest',
            Netlink.NLMSGERR_ATTR_OFFS: 'bad-attr-offs',
        }

        self.extack = {}
        for extack in NlAttrs(self.raw[extack_off:]):
            if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                self.extack['msg'] = extack.as_strz()
            elif extack.type in u32_fields:
                self.extack[u32_fields[extack.type]] = extack.as_scalar('u32')
            elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                self.extack['policy'] = _genl_decode_policy(extack.raw)
            else:
                # Keep what we can't decode, in raw NlAttr form
                self.extack.setdefault('unknown', []).append(extack)

        if attr_space:
            self.annotate_extack(attr_space)

    def annotate_extack(self, attr_space):
        """ Make extack more human friendly with attribute information """

        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' not in self.extack or 'miss-nest' in self.extack:
            return

        miss_type = self.extack['miss-type']
        if miss_type not in attr_space.attrs_by_val:
            return

        # Replace the numeric id with the name from the spec
        spec = attr_space.attrs_by_val[miss_type]
        self.extack['miss-type'] = spec['name']
        if 'doc' in spec:
            self.extack['miss-type-doc'] = spec['doc']

    def cmd(self):
        """Return the message type from the netlink header."""
        return self.nl_type

    def __repr__(self):
        lines = [f"nl_len = {self.nl_len} ({len(self.raw)}) "
                 f"nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"]
        if self.error:
            lines.append('\n\terror: ' + str(self.error))
        if self.extack:
            lines.append('\n\textack: ' + repr(self.extack))
        return ''.join(lines)
403
404
405# pylint: disable=too-few-public-methods
class NlMsgs:
    """All netlink messages contained in one receive buffer.

    data is split into NlMsg objects, stored in self.msgs in buffer order.

    Raises YnlException if a message header carries a length smaller than
    the 16 byte header itself.
    """
    def __init__(self, data):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset)
            # nl_len includes the 16 byte header. A smaller value is
            # malformed, and 0 would keep us on the same offset forever.
            if msg.nl_len < 16:
                raise YnlException(
                    f"Malformed netlink message at offset {offset}: length {msg.nl_len}")
            # Messages start on 4 byte boundaries, round the length up
            # to find the next header (what NLMSG_NEXT() does in C).
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
418
419
420def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
421    # we prepend length in _genl_msg_finalize()
422    if seq is None:
423        seq = random.randint(1, 1024)
424    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
425    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
426    return nlmsg + genlmsg
427
428
429def _genl_msg_finalize(msg):
430    return struct.pack("I", len(msg) + 4) + msg
431
432
def _genl_decode_policy(raw):
    """Decode the attributes describing one policy entry into a dict.

    The result may contain the keys 'type', 'min-value', 'max-value',
    'min-length', 'max-length', 'policy-idx', 'bitfield32-mask' and
    'mask', depending on which attributes are present in raw. Attributes
    of other types are ignored.
    """
    # Attribute type -> (key in the result, scalar type of the payload)
    scalar_fields = {
        Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S: ('min-value', 's64'),
        Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S: ('max-value', 's64'),
        Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U: ('min-value', 'u64'),
        Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U: ('max-value', 'u64'),
        Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH: ('min-length', 'u32'),
        Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH: ('max-length', 'u32'),
        Netlink.NL_POLICY_TYPE_ATTR_POLICY_IDX: ('policy-idx', 'u32'),
        Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK: ('bitfield32-mask', 'u32'),
        Netlink.NL_POLICY_TYPE_ATTR_MASK: ('mask', 'u64'),
    }

    policy = {}
    for attr in NlAttrs(raw):
        if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
            # The type is reported as a number, translate it to the name
            policy['type'] = Netlink.AttrType(attr.as_scalar('u32')).name
            continue

        field = scalar_fields.get(attr.type)
        if field is not None:
            key, scalar_type = field
            policy[key] = attr.as_scalar(scalar_type)
    return policy
458
459
460# pylint: disable=too-many-nested-blocks
def _genl_load_families():
    """Dump the generic netlink families registered in the kernel.

    Sends a CTRL_CMD_GETFAMILY dump request over a temporary socket and
    returns a dict mapping family name to a dict with the keys 'id',
    'name' and, if reported, 'maxattr' and 'mcast' (multicast group
    name -> group id).

    Raises YnlException if the kernel replies with an error.
    """
    def decode_mcast_groups(raw):
        # Each entry is a nest holding the name and the id of one group
        groups = {}
        for entry in NlAttrs(raw):
            grp_name = None
            grp_id = None
            for grp_attr in NlAttrs(entry.raw):
                if grp_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                    grp_name = grp_attr.as_strz()
                elif grp_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                    grp_id = grp_attr.as_scalar('u32')
            if grp_name and grp_id is not None:
                groups[grp_name] = grp_id
        return groups

    families = {}

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL,
                      Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                      Netlink.CTRL_CMD_GETFAMILY, 1))
        sock.send(request, 0)

        # Keep reading until the message which terminates the dump
        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                if nl_msg.error:
                    raise YnlException(f"Netlink error: {nl_msg.error}")
                if nl_msg.done:
                    return families

                fam = {}
                for attr in NlAttrs(GenlMsg(nl_msg).raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = decode_mcast_groups(attr.raw)

                # Families without a name or an id are not recorded
                if 'name' in fam and 'id' in fam:
                    families[fam['name']] = fam
506
507
508# pylint: disable=too-many-nested-blocks
def _genl_policy_dump(family_id, op):
    """Dump the kernel policy of one operation of a family.

    Sends a CTRL_CMD_GETPOLICY dump request for the given family id and
    operation over a temporary socket.

    Returns a tuple (op_policy, policy_table):
      op_policy    -- dict with the policy index for 'do' and / or 'dump'
      policy_table -- dict: policy index -> {attr id -> decoded policy}

    Raises YnlException if the kernel replies with an error.
    """
    # Attribute type inside the per-op nest -> key in op_policy
    method_keys = {
        Netlink.CTRL_ATTR_POLICY_DO: 'do',
        Netlink.CTRL_ATTR_POLICY_DUMP: 'dump',
    }
    op_policy = {}
    policy_table = {}

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        request = _genl_msg(Netlink.GENL_ID_CTRL,
                            Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                            Netlink.CTRL_CMD_GETPOLICY, 1)
        # u16 family id, followed by two bytes of padding
        request += struct.pack('HHHxx', 6, Netlink.CTRL_ATTR_FAMILY_ID, family_id)
        request += struct.pack('HHI', 8, Netlink.CTRL_ATTR_OP, op)
        sock.send(_genl_msg_finalize(request), 0)

        # Keep reading until the message which terminates the dump
        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                if nl_msg.error:
                    raise YnlException(f"Netlink error: {nl_msg.error}")
                if nl_msg.done:
                    return op_policy, policy_table

                for attr in NlAttrs(GenlMsg(nl_msg).raw):
                    if attr.type == Netlink.CTRL_ATTR_OP_POLICY:
                        for op_attr in NlAttrs(attr.raw):
                            for method_attr in NlAttrs(op_attr.raw):
                                if method_attr.type in method_keys:
                                    key = method_keys[method_attr.type]
                                    op_policy[key] = method_attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_POLICY:
                        # Two levels of nesting, the attribute types are
                        # the policy index and the attribute id.
                        for pidx_attr in NlAttrs(attr.raw):
                            for aid_attr in NlAttrs(pidx_attr.raw):
                                decoded = _genl_decode_policy(aid_attr.raw)
                                entries = policy_table.setdefault(pidx_attr.type, {})
                                entries[aid_attr.type] = decoded
552
553
class GenlMsg:
    """Generic netlink view of a netlink message.

    Unpacks the 4 byte genetlink header (command, version, 16 reserved
    bits) from the start of nl_msg.raw, and keeps the remainder in raw.
    raw_attrs starts out empty.
    """
    def __init__(self, nl_msg):
        self.nl = nl_msg
        genl_hdr = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.genl_cmd = genl_hdr[0]
        self.genl_version = genl_hdr[1]
        self.raw = nl_msg.raw[4:]
        self.raw_attrs = []

    def cmd(self):
        """Return the genetlink command of the message."""
        return self.genl_cmd

    def __repr__(self):
        parts = [repr(self.nl),
                 f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"]
        parts.extend('\t\t' + repr(a) + '\n' for a in self.raw_attrs)
        return ''.join(parts)
570
571
class NetlinkProtocol:
    """Message framing for a raw netlink family.

    Attributes:
        family_name -- name of the family
        proto_num   -- netlink protocol number to open the socket with
    """
    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack the netlink header fields which follow the length.

        A random sequence number in the range 1..1024 is used if seq
        is None.
        """
        seq_num = random.randint(1, 1024) if seq is None else seq
        return struct.pack("HHII", nl_type, nl_flags, seq_num, 0)

    def message(self, flags, command, _version, seq=None):
        """Start a request, the command is used as the message type."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Attach the parsed attributes to a received message.

        If op is None the operation is looked up in ynl.rsp_by_value,
        based on the command of the message. Attribute parsing starts
        after the fixed header of the operation.
        """
        decoded = self._decode(nl_msg)
        rsp_op = ynl.rsp_by_value[decoded.cmd()] if op is None else op
        attrs_start = ynl.struct_size(rsp_op.fixed_header)
        decoded.raw_attrs = NlAttrs(decoded.raw, attrs_start)
        return decoded

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the id of a multicast group, as defined in the spec.

        Raises YnlException if the spec has no such group.
        """
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise YnlException(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        """Return the size of the message header in bytes."""
        return 16
604
605
class GenlProtocol(NetlinkProtocol):
    """Message framing for a generic netlink family.

    The families known to the kernel are loaded on first use and cached
    in the class attribute genl_family_name_to_id.
    """
    genl_family_name_to_id = {}

    def __init__(self, family_name):
        """Look the family up among the families known to the kernel.

        Raises KeyError if the kernel did not report the family.
        """
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        families = GenlProtocol.genl_family_name_to_id
        if not families:
            families = _genl_load_families()
            GenlProtocol.genl_family_name_to_id = families

        self.genl_family = families[family_name]
        self.family_id = self.genl_family['id']

    def message(self, flags, command, version, seq=None):
        """Start a request: netlink header fields plus genetlink header.

        The family id is used as the netlink message type.
        """
        nl_hdr = self._message(self.family_id, flags, seq)
        return nl_hdr + struct.pack("BBH", command, version, 0)

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the id of a multicast group, as reported by the kernel.

        Raises YnlException if the family has no such group.
        """
        known_groups = self.genl_family['mcast']
        if mcast_name in known_groups:
            return known_groups[mcast_name]
        raise YnlException(f'Multicast group "{mcast_name}" not present in the family')

    def msghdr_size(self):
        """Return the size of the headers: netlink plus 4 bytes of genetlink."""
        return super().msghdr_size() + 4
633
634
635# pylint: disable=too-few-public-methods
class SpaceAttrs:
    """Chain of attribute scopes, searched from the innermost outwards.

    Each scope pairs an attribute space (spec) with the values provided
    for it. The scopes of outer, if given, follow the new scope.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        innermost = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [innermost]
        if outer:
            self.scopes = self.scopes + outer.scopes

    def lookup(self, name):
        """Return the value of attribute name.

        The first scope whose spec defines name decides the result, scopes
        further out are not consulted even if that one has no value.

        Raises YnlException if the deciding scope has no value for name,
        or if no scope defines name at all.
        """
        for spec, values in self.scopes:
            if name not in spec:
                continue
            if name in values:
                return values[name]
            spec_name = spec.yaml['name']
            raise YnlException(
                f"No value for '{name}' in attribute space '{spec_name}'")
        raise YnlException(f"Attribute '{name}' not defined in any attribute-set")
653
654
655#
656# YNL implementation details.
657#
658
659
660class YnlFamily(SpecFamily):
661    """
662    YNL family -- a Netlink interface built from a YAML spec.
663
664    Primary use of the class is to execute Netlink commands:
665
666      ynl.<op_name>(attrs, ...)
667
668    By default this will execute the <op_name> as "do", pass dump=True
669    to perform a dump operation.
670
671    ynl.<op_name> is a shorthand / convenience wrapper for the following
672    methods which take the op_name as a string:
673
674      ynl.do(op_name, attrs, flags=None) -- execute a do operation
675      ynl.dump(op_name, attrs)           -- execute a dump operation
676      ynl.do_multi(ops)                  -- batch multiple do operations
677
678    The flags argument in ynl.do() allows passing in extra NLM_F_* flags
679    which may be necessary for old families.
680
681    Notification API:
682
683      ynl.ntf_subscribe(mcast_name)      -- join a multicast group
684      ynl.ntf_listen_all_nsid()          -- listen on all netns
685      ynl.check_ntf()                    -- drain pending notifications
686      ynl.poll_ntf(duration=None)        -- yield notifications
687
688    Policy introspection allows querying validation criteria from the running
689    kernel. Allows checking whether kernel supports a given attribute or value.
690
691      ynl.get_policy(op_name, mode)      -- query kernel policy for an op
692    """
693    def __init__(self, def_path, schema=None, process_unknown=False,
694                 recv_size=0):
695        super().__init__(def_path, schema)
696
697        self.include_raw = False
698        self.process_unknown = process_unknown
699
700        try:
701            if self.proto == "netlink-raw":
702                self.nlproto = NetlinkProtocol(self.yaml['name'],
703                                               self.yaml['protonum'])
704            else:
705                self.nlproto = GenlProtocol(self.yaml['name'])
706        except KeyError as err:
707            raise YnlException(f"Family '{self.yaml['name']}' not supported by the kernel") from err
708
709        self._recv_dbg = False
710        # Note that netlink will use conservative (min) message size for
711        # the first dump recv() on the socket, our setting will only matter
712        # from the second recv() on.
713        self._recv_size = recv_size if recv_size else 131072
714        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
715        # for a message, so smaller receive sizes will lead to truncation.
716        # Note that the min size for other families may be larger than 4k!
717        if self._recv_size < 4000:
718            raise ConfigError()
719
720        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
721        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
722        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
723        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
724
725        self.async_msg_ids = set()
726        self.async_msg_queue = queue.Queue()
727
728        for msg in self.msgs.values():
729            if msg.is_async:
730                self.async_msg_ids.add(msg.rsp_value)
731
732        for op_name, op in self.ops.items():
733            bound_f = functools.partial(self._op, op_name)
734            setattr(self, op.ident_name, bound_f)
735
736    def close(self):
737        if self.sock is not None:
738            self.sock.close()
739            self.sock = None
740
    def __enter__(self):
        """Enter a 'with' block, the family object itself is the target."""
        return self
743
    def __exit__(self, exc_type, exc, tb):
        """Leave a 'with' block: close the socket.

        Returns None, so an exception raised inside the block propagates.
        """
        self.close()
746
747    def ntf_subscribe(self, mcast_name):
748        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
749        self.sock.bind((0, 0))
750        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
751                             mcast_id)
752
753    def ntf_listen_all_nsid(self):
754        """Enable NETLINK_LISTEN_ALL_NSID to receive notifications from all
755        namespaces that have an nsid mapped in the current one."""
756        self.sock.setsockopt(Netlink.SOL_NETLINK,
757                             Netlink.NETLINK_LISTEN_ALL_NSID, 1)
758
759    @staticmethod
760    def _decode_nsid(ancdata):
761        for cmsg_level, cmsg_type, cmsg_data in ancdata:
762            if (cmsg_level == Netlink.SOL_NETLINK and
763                    cmsg_type == Netlink.NETLINK_LISTEN_ALL_NSID):
764                nsid = struct.unpack('i', cmsg_data)[0]
765                if nsid >= 0:
766                    return nsid
767                return None
768        return None
769
    def set_recv_dbg(self, enabled):
        """Enable or disable the receive debug output, see _recv_dbg_print()."""
        self._recv_dbg = enabled
772
773    def _recv_dbg_print(self, reply, nl_msgs):
774        if not self._recv_dbg:
775            return
776        print("Recv: read", len(reply), "bytes,",
777              len(nl_msgs.msgs), "messages", file=sys.stderr)
778        for nl_msg in nl_msgs:
779            print("  ", nl_msg, file=sys.stderr)
780
781    def _encode_enum(self, attr_spec, value):
782        enum = self.consts[attr_spec['enum']]
783        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
784            scalar = 0
785            if isinstance(value, str):
786                value = [value]
787            for single_value in value:
788                scalar += enum.entries[single_value].user_value(as_flags = True)
789            return scalar
790        return enum.entries[value].user_value()
791
792    def _get_scalar(self, attr_spec, value):
793        try:
794            return int(value)
795        except (ValueError, TypeError) as e:
796            if 'enum' in attr_spec:
797                return self._encode_enum(attr_spec, value)
798            if attr_spec.display_hint:
799                return self._from_string(value, attr_spec)
800            raise e
801
802    # pylint: disable=too-many-statements
803    def _add_attr(self, space, name, value, search_attrs):
804        try:
805            attr = self.attr_sets[space][name]
806        except KeyError as err:
807            raise YnlException(f"Space '{space}' has no attribute '{name}'") from err
808        nl_type = attr.value
809
810        if attr.is_multi and isinstance(value, list):
811            attr_payload = b''
812            for subvalue in value:
813                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
814            return attr_payload
815
816        if attr["type"] == 'nest':
817            nl_type |= Netlink.NLA_F_NESTED
818            sub_space = attr['nested-attributes']
819            attr_payload = self._add_nest_attrs(value, sub_space, search_attrs)
820        elif attr['type'] == 'indexed-array' and attr['sub-type'] == 'nest':
821            nl_type |= Netlink.NLA_F_NESTED
822            sub_space = attr['nested-attributes']
823            attr_payload = self._encode_indexed_array(value, sub_space,
824                                                      search_attrs)
825        elif attr["type"] == 'flag':
826            if not value:
827                # If value is absent or false then skip attribute creation.
828                return b''
829            attr_payload = b''
830        elif attr["type"] == 'string':
831            attr_payload = str(value).encode('ascii') + b'\x00'
832        elif attr["type"] == 'binary':
833            if value is None:
834                attr_payload = b''
835            elif isinstance(value, bytes):
836                attr_payload = value
837            elif isinstance(value, str):
838                if attr.display_hint:
839                    attr_payload = self._from_string(value, attr)
840                else:
841                    attr_payload = bytes.fromhex(value)
842            elif isinstance(value, dict) and attr.struct_name:
843                attr_payload = self._encode_struct(attr.struct_name, value)
844            elif isinstance(value, list) and attr.sub_type in NlAttr.type_formats:
845                format_ = NlAttr.get_format(attr.sub_type)
846                attr_payload = b''.join([format_.pack(x) for x in value])
847            else:
848                raise YnlException(f'Unknown type for binary attribute, value: {value}')
849        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
850            scalar = self._get_scalar(attr, value)
851            if attr.is_auto_scalar:
852                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
853            else:
854                attr_type = attr["type"]
855            format_ = NlAttr.get_format(attr_type, attr.byte_order)
856            attr_payload = format_.pack(scalar)
857        elif attr['type'] in "bitfield32":
858            scalar_value = self._get_scalar(attr, value["value"])
859            scalar_selector = self._get_scalar(attr, value["selector"])
860            attr_payload = struct.pack("II", scalar_value, scalar_selector)
861        elif attr['type'] == 'sub-message':
862            msg_format, _ = self._resolve_selector(attr, search_attrs)
863            attr_payload = b''
864            if msg_format.fixed_header:
865                attr_payload += self._encode_struct(msg_format.fixed_header, value)
866            if msg_format.attr_set:
867                if msg_format.attr_set in self.attr_sets:
868                    nl_type |= Netlink.NLA_F_NESTED
869                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
870                    for subname, subvalue in value.items():
871                        attr_payload += self._add_attr(msg_format.attr_set,
872                                                       subname, subvalue, sub_attrs)
873                else:
874                    raise YnlException(f"Unknown attribute-set '{msg_format.attr_set}'")
875        else:
876            raise YnlException(f'Unknown type at {space} {name} {value} {attr["type"]}')
877
878        return self._add_attr_raw(nl_type, attr_payload)
879
880    def _add_attr_raw(self, nl_type, attr_payload):
881        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
882        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
883
884    def _add_nest_attrs(self, value, sub_space, search_attrs):
885        sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
886        attr_payload = b''
887        for subname, subvalue in value.items():
888            attr_payload += self._add_attr(sub_space, subname, subvalue,
889                                           sub_attrs)
890        return attr_payload
891
892    def _encode_indexed_array(self, vals, sub_space, search_attrs):
893        attr_payload = b''
894        for i, val in enumerate(vals):
895            idx = i | Netlink.NLA_F_NESTED
896            val_payload = self._add_nest_attrs(val, sub_space, search_attrs)
897            attr_payload += self._add_attr_raw(idx, val_payload)
898        return attr_payload
899
900    def _get_enum_or_unknown(self, enum, raw):
901        try:
902            name = enum.entries_by_val[raw].name
903        except KeyError as error:
904            if self.process_unknown:
905                name = f"Unknown({raw})"
906            else:
907                raise error
908        return name
909
910    def _decode_enum(self, raw, attr_spec):
911        enum = self.consts[attr_spec['enum']]
912        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
913            i = 0
914            value = set()
915            while raw:
916                if raw & 1:
917                    value.add(self._get_enum_or_unknown(enum, i))
918                raw >>= 1
919                i += 1
920        else:
921            value = self._get_enum_or_unknown(enum, raw)
922        return value
923
924    def _decode_binary(self, attr, attr_spec):
925        if attr_spec.struct_name:
926            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
927        elif attr_spec.sub_type:
928            decoded = attr.as_c_array(attr_spec.sub_type)
929            if 'enum' in attr_spec:
930                decoded = [ self._decode_enum(x, attr_spec) for x in decoded ]
931            elif attr_spec.display_hint:
932                decoded = [ self._formatted_string(x, attr_spec.display_hint)
933                            for x in decoded ]
934        else:
935            decoded = attr.as_bin()
936            if attr_spec.display_hint:
937                decoded = self._formatted_string(decoded, attr_spec.display_hint)
938        return decoded
939
940    def _decode_array_attr(self, attr, attr_spec):
941        decoded = []
942        offset = 0
943        while offset < len(attr.raw):
944            item = NlAttr(attr.raw, offset)
945            offset += item.full_len
946
947            if attr_spec["sub-type"] == 'nest':
948                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
949                decoded.append({ item.type: subattrs })
950            elif attr_spec["sub-type"] == 'binary':
951                subattr = item.as_bin()
952                if attr_spec.display_hint:
953                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
954                decoded.append(subattr)
955            elif attr_spec["sub-type"] in NlAttr.type_formats:
956                subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
957                if 'enum' in attr_spec:
958                    subattr = self._decode_enum(subattr, attr_spec)
959                elif attr_spec.display_hint:
960                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
961                decoded.append(subattr)
962            else:
963                raise YnlException(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
964        return decoded
965
966    def _decode_nest_type_value(self, attr, attr_spec):
967        decoded = {}
968        value = attr
969        for name in attr_spec['type-value']:
970            value = NlAttr(value.raw, 0)
971            decoded[name] = value.type
972        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
973        decoded.update(subattrs)
974        return decoded
975
976    def _decode_unknown(self, attr):
977        if attr.is_nest:
978            return self._decode(NlAttrs(attr.raw), None)
979        return attr.as_bin()
980
981    def _rsp_add(self, rsp, name, is_multi, decoded):
982        if is_multi is None:
983            if name in rsp and not isinstance(rsp[name], list):
984                rsp[name] = [rsp[name]]
985                is_multi = True
986            else:
987                is_multi = False
988
989        if not is_multi:
990            rsp[name] = decoded
991        elif name in rsp:
992            rsp[name].append(decoded)
993        else:
994            rsp[name] = [decoded]
995
996    def _resolve_selector(self, attr_spec, search_attrs):
997        sub_msg = attr_spec.sub_message
998        if sub_msg not in self.sub_msgs:
999            raise YnlException(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
1000        sub_msg_spec = self.sub_msgs[sub_msg]
1001
1002        selector = attr_spec.selector
1003        value = search_attrs.lookup(selector)
1004        if value not in sub_msg_spec.formats:
1005            raise YnlException(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
1006
1007        spec = sub_msg_spec.formats[value]
1008        return spec, value
1009
1010    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
1011        msg_format, _ = self._resolve_selector(attr_spec, search_attrs)
1012        decoded = {}
1013        offset = 0
1014        if msg_format.fixed_header:
1015            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
1016            offset = self.struct_size(msg_format.fixed_header)
1017        if msg_format.attr_set:
1018            if msg_format.attr_set in self.attr_sets:
1019                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
1020                decoded.update(subdict)
1021            else:
1022                raise YnlException(f"Unknown attribute-set '{msg_format.attr_set}' "
1023                                   f"when decoding '{attr_spec.name}'")
1024        return decoded
1025
1026    # pylint: disable=too-many-statements
1027    def _decode(self, attrs, space, outer_attrs = None):
1028        rsp = {}
1029        search_attrs = {}
1030        if space:
1031            attr_space = self.attr_sets[space]
1032            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
1033
1034        for attr in attrs:
1035            try:
1036                attr_spec = attr_space.attrs_by_val[attr.type]
1037            except (KeyError, UnboundLocalError) as err:
1038                if not self.process_unknown:
1039                    raise YnlException(f"Space '{space}' has no attribute "
1040                                       f"with value '{attr.type}'") from err
1041                attr_name = f"UnknownAttr({attr.type})"
1042                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
1043                continue
1044
1045            try:
1046                if attr_spec["type"] == 'pad':
1047                    continue
1048                elif attr_spec["type"] == 'nest':
1049                    subdict = self._decode(NlAttrs(attr.raw),
1050                                           attr_spec['nested-attributes'],
1051                                           search_attrs)
1052                    decoded = subdict
1053                elif attr_spec["type"] == 'string':
1054                    decoded = attr.as_strz()
1055                elif attr_spec["type"] == 'binary':
1056                    decoded = self._decode_binary(attr, attr_spec)
1057                elif attr_spec["type"] == 'flag':
1058                    decoded = True
1059                elif attr_spec.is_auto_scalar:
1060                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
1061                    if 'enum' in attr_spec:
1062                        decoded = self._decode_enum(decoded, attr_spec)
1063                elif attr_spec["type"] in NlAttr.type_formats:
1064                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
1065                    if 'enum' in attr_spec:
1066                        decoded = self._decode_enum(decoded, attr_spec)
1067                    elif attr_spec.display_hint:
1068                        decoded = self._formatted_string(decoded, attr_spec.display_hint)
1069                elif attr_spec["type"] == 'indexed-array':
1070                    decoded = self._decode_array_attr(attr, attr_spec)
1071                elif attr_spec["type"] == 'bitfield32':
1072                    value, selector = struct.unpack("II", attr.raw)
1073                    if 'enum' in attr_spec:
1074                        value = self._decode_enum(value, attr_spec)
1075                        selector = self._decode_enum(selector, attr_spec)
1076                    decoded = {"value": value, "selector": selector}
1077                elif attr_spec["type"] == 'sub-message':
1078                    decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
1079                elif attr_spec["type"] == 'nest-type-value':
1080                    decoded = self._decode_nest_type_value(attr, attr_spec)
1081                else:
1082                    if not self.process_unknown:
1083                        raise YnlException(f'Unknown {attr_spec["type"]} '
1084                                           f'with name {attr_spec["name"]}')
1085                    decoded = self._decode_unknown(attr)
1086
1087                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
1088            except:
1089                print(f"Error decoding '{attr_spec.name}' from '{space}'")
1090                raise
1091
1092        return rsp
1093
1094    # pylint: disable=too-many-arguments, too-many-positional-arguments
    def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs):
        """Build the dotted path of the attribute located at byte offset target.

        Args:
            attrs: iterable of the attributes at the current nesting level
            attr_set: attribute set describing attrs
            offset: byte offset at which the first attribute of attrs starts
            target: byte offset of the attribute being searched for
            search_attrs: lookup scope for sub-message selectors

        Returns:
            A string such as ".outer.inner", with "(<selector value>)"
            appended to sub-message names, or None when no attribute starts
            exactly at target.

        Raises:
            YnlException: an attribute is missing from the attribute set, or
                target lies inside an attribute that is neither a nest nor a
                sub-message.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError as err:
                raise YnlException(
                    f"Space '{attr_set.name}' has no attribute with value '{attr.type}'") from err
            # Walked past the target without an exact hit.
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            # Target is beyond this attribute, step over it.
            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue

            # Target lies inside this attribute, descend into it.
            pathname = attr_spec.name
            if attr_spec['type'] == 'nest':
                sub_attrs = self.attr_sets[attr_spec['nested-attributes']]
                search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name']))
            elif attr_spec['type'] == 'sub-message':
                msg_format, value = self._resolve_selector(attr_spec, search_attrs)
                if msg_format is None:
                    raise YnlException(f"Can't resolve sub-message of "
                                       f"{attr_spec['name']} for extack")
                sub_attrs = self.attr_sets[msg_format.attr_set]
                pathname += f"({value})"
            else:
                raise YnlException(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Skip the 4-byte attribute header to reach the nested payload.
            # NOTE(review): a sub-message fixed header is not added to the
            # offset here - confirm whether that case can occur.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs,
                                               offset, target, search_attrs)
            if subpath is None:
                return None
            return '.' + pathname + subpath

        return None
1132
1133    def _decode_extack(self, request, op, extack, vals):
1134        if 'bad-attr-offs' not in extack:
1135            return
1136
1137        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
1138        offset = self.nlproto.msghdr_size() + self.struct_size(op.fixed_header)
1139        search_attrs = SpaceAttrs(op.attr_set, vals)
1140        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
1141                                        extack['bad-attr-offs'], search_attrs)
1142        if path:
1143            del extack['bad-attr-offs']
1144            extack['bad-attr'] = path
1145
1146    def struct_size(self, name):
1147        if name:
1148            members = self.consts[name].members
1149            size = 0
1150            for m in members:
1151                if m.type in ['pad', 'binary']:
1152                    if m.struct:
1153                        size += self.struct_size(m.struct)
1154                    else:
1155                        size += m.len
1156                else:
1157                    format_ = NlAttr.get_format(m.type, m.byte_order)
1158                    size += format_.size
1159            return size
1160        return 0
1161
1162    def _decode_struct(self, data, name):
1163        members = self.consts[name].members
1164        attrs = {}
1165        offset = 0
1166        for m in members:
1167            value = None
1168            if m.type == 'pad':
1169                offset += m.len
1170            elif m.type == 'binary':
1171                if m.struct:
1172                    len_ = self.struct_size(m.struct)
1173                    value = self._decode_struct(data[offset : offset + len_],
1174                                                m.struct)
1175                    offset += len_
1176                else:
1177                    value = data[offset : offset + m.len]
1178                    offset += m.len
1179            else:
1180                format_ = NlAttr.get_format(m.type, m.byte_order)
1181                [ value ] = format_.unpack_from(data, offset)
1182                offset += format_.size
1183            if value is not None:
1184                if m.enum:
1185                    value = self._decode_enum(value, m)
1186                elif m.display_hint:
1187                    value = self._formatted_string(value, m.display_hint)
1188                attrs[m.name] = value
1189        return attrs
1190
1191    def _encode_struct(self, name, vals):
1192        members = self.consts[name].members
1193        attr_payload = b''
1194        for m in members:
1195            value = vals.pop(m.name) if m.name in vals else None
1196            if m.type == 'pad':
1197                attr_payload += bytearray(m.len)
1198            elif m.type == 'binary':
1199                if m.struct:
1200                    if value is None:
1201                        value = {}
1202                    attr_payload += self._encode_struct(m.struct, value)
1203                else:
1204                    if value is None:
1205                        attr_payload += bytearray(m.len)
1206                    else:
1207                        attr_payload += bytes.fromhex(value)
1208            else:
1209                if value is None:
1210                    value = 0
1211                format_ = NlAttr.get_format(m.type, m.byte_order)
1212                attr_payload += format_.pack(value)
1213        return attr_payload
1214
1215    def _formatted_string(self, raw, display_hint):
1216        if display_hint == 'mac':
1217            formatted = ':'.join(f'{b:02x}' for b in raw)
1218        elif display_hint == 'hex':
1219            if isinstance(raw, int):
1220                formatted = hex(raw)
1221            else:
1222                formatted = bytes.hex(raw, ' ')
1223        elif display_hint in [ 'ipv4', 'ipv6', 'ipv4-or-v6' ]:
1224            formatted = format(ipaddress.ip_address(raw))
1225        elif display_hint == 'uuid':
1226            formatted = str(uuid.UUID(bytes=raw))
1227        else:
1228            formatted = raw
1229        return formatted
1230
1231    def _from_string(self, string, attr_spec):
1232        if attr_spec.display_hint in ['ipv4', 'ipv6', 'ipv4-or-v6']:
1233            ip = ipaddress.ip_address(string)
1234            if attr_spec['type'] == 'binary':
1235                raw = ip.packed
1236            else:
1237                raw = int(ip)
1238        elif attr_spec.display_hint == 'hex':
1239            if attr_spec['type'] == 'binary':
1240                raw = bytes.fromhex(string)
1241            else:
1242                raw = int(string, 16)
1243        elif attr_spec.display_hint == 'mac':
1244            # Parse MAC address in format "00:11:22:33:44:55" or "001122334455"
1245            if ':' in string:
1246                mac_bytes = [int(x, 16) for x in string.split(':')]
1247            else:
1248                if len(string) % 2 != 0:
1249                    raise YnlException(f"Invalid MAC address format: {string}")
1250                mac_bytes = [int(string[i:i+2], 16) for i in range(0, len(string), 2)]
1251            raw = bytes(mac_bytes)
1252        else:
1253            raise YnlException(f"Display hint '{attr_spec.display_hint}' not implemented"
1254                            f" when parsing '{attr_spec['name']}'")
1255        return raw
1256
1257    def handle_ntf(self, decoded, nsid=None):
1258        msg = {}
1259        if self.include_raw:
1260            msg['raw'] = decoded
1261        op = self.rsp_by_value[decoded.cmd()]
1262        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
1263        if op.fixed_header:
1264            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
1265
1266        msg['name'] = op['name']
1267        msg['msg'] = attrs
1268        if nsid is not None:
1269            msg['nsid'] = nsid
1270        self.async_msg_queue.put(msg)
1271
1272    def _recvmsg(self, flags=0):
1273        reply, ancdata, _, _ = self.sock.recvmsg(self._recv_size, 4096, flags)
1274        return reply, ancdata
1275
1276    def check_ntf(self):
1277        while True:
1278            try:
1279                reply, ancdata = self._recvmsg(socket.MSG_DONTWAIT)
1280            except BlockingIOError:
1281                return
1282
1283            nsid = self._decode_nsid(ancdata)
1284            nms = NlMsgs(reply)
1285            self._recv_dbg_print(reply, nms)
1286            for nl_msg in nms:
1287                if nl_msg.error:
1288                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
1289                    print(nl_msg)
1290                    continue
1291                if nl_msg.done:
1292                    print("Netlink done while checking for ntf!?")
1293                    continue
1294
1295                decoded = self.nlproto.decode(self, nl_msg, None)
1296                if decoded.cmd() not in self.async_msg_ids:
1297                    print("Unexpected msg id while checking for ntf", decoded)
1298                    continue
1299
1300                self.handle_ntf(decoded, nsid)
1301
1302    def poll_ntf(self, duration=None):
1303        start_time = time.time()
1304        selector = selectors.DefaultSelector()
1305        selector.register(self.sock, selectors.EVENT_READ)
1306
1307        while True:
1308            try:
1309                yield self.async_msg_queue.get_nowait()
1310            except queue.Empty:
1311                if duration is not None:
1312                    timeout = start_time + duration - time.time()
1313                    if timeout <= 0:
1314                        return
1315                else:
1316                    timeout = None
1317                events = selector.select(timeout)
1318                if events:
1319                    self.check_ntf()
1320
1321    def operation_do_attributes(self, name):
1322        """
1323        For a given operation name, find and return a supported
1324        set of attributes (as a dict).
1325        """
1326        op = self.find_operation(name)
1327        if not op:
1328            return None
1329
1330        return op['do']['request']['attributes'].copy()
1331
1332    def _encode_message(self, op, vals, flags, req_seq):
1333        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
1334        for flag in flags or []:
1335            nl_flags |= flag
1336
1337        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
1338        if op.fixed_header:
1339            msg += self._encode_struct(op.fixed_header, vals)
1340        search_attrs = SpaceAttrs(op.attr_set, vals)
1341        for name, value in vals.items():
1342            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
1343        msg = _genl_msg_finalize(msg)
1344        return msg
1345
1346    # pylint: disable=too-many-statements
    def _ops(self, ops):
        """Send a batch of requests and collect the reply to each of them.

        Args:
            ops: iterable of (method name, values dict, flags list) tuples

        Returns:
            A list with one entry per completed request: the list of decoded
            replies if NLM_F_DUMP is among the request's flags, otherwise
            None for no reply, the reply dict for exactly one reply, or the
            list of replies for several.

        Raises:
            NlError: a received message carries an error.
        """
        reqs_by_seq = {}
        # Consecutive sequence numbers, starting from a random one, are used
        # to match replies to their requests.
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            reqs_by_seq[req_seq] = (op, vals, msg, flags)
            payload += msg
            req_seq += 1

        # All requests go out in a single send.
        self.sock.send(payload, 0)

        done = False
        rsp = []
        op_rsp = []
        while not done:
            reply, ancdata = self._recvmsg()
            nsid = self._decode_nsid(ancdata)
            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    if nl_msg.extack:
                        nl_msg.annotate_extack(op.attr_set)
                        self._decode_extack(req_msg, op, nl_msg.extack, vals)
                else:
                    # Not a reply to any pending request.
                    op = None
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    # The request is complete, shape its result.
                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    # NOTE(review): raises KeyError if a done message carries
                    # a sequence number that is not pending - confirm this
                    # cannot happen.
                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    # NOTE(review): the messages left in this datagram are
                    # not processed after a done message.
                    break

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        # A notification arrived in between, queue it.
                        self.handle_ntf(decoded, nsid)
                        continue
                    print('Unexpected message: ' + repr(decoded))
                    continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp
1415
1416    def _op(self, method, vals, flags=None, dump=False):
1417        req_flags = flags or []
1418        if dump:
1419            req_flags.append(Netlink.NLM_F_DUMP)
1420
1421        ops = [(method, vals, req_flags)]
1422        return self._ops(ops)[0]
1423
1424    def do(self, method, vals, flags=None):
1425        return self._op(method, vals, flags)
1426
1427    def dump(self, method, vals):
1428        return self._op(method, vals, dump=True)
1429
1430    def do_multi(self, ops):
1431        return self._ops(ops)
1432
1433    def get_policy(self, op_name, mode):
1434        """Query running kernel for the Netlink policy of an operation.
1435
1436        Allows checking whether kernel supports a given attribute or value.
1437        This method consults the running kernel, not the YAML spec.
1438
1439        Args:
1440            op_name: operation name as it appears in the YAML spec
1441            mode: 'do' or 'dump'
1442
1443        Returns:
1444            NlPolicy acting as a read-only dict mapping attribute names
1445            to their policy properties (type, min/max, nested, etc.),
1446            or None if the operation has no policy for the given mode.
1447            Empty policy usually implies that the operation rejects
1448            all attributes.
1449        """
1450        op = self.ops[op_name]
1451        op_policy, policy_table = _genl_policy_dump(self.nlproto.family_id,
1452                                                    op.req_value)
1453        if mode not in op_policy:
1454            return None
1455        policy_idx = op_policy[mode]
1456        return NlPolicy(self, policy_idx, policy_table, op.attr_set)
1457