xref: /linux/tools/net/ynl/lib/ynl.py (revision 6ca80638b90cec66547011ee1ef79e534589989a)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4import functools
5import os
6import random
7import socket
8import struct
9from struct import Struct
10import yaml
11import ipaddress
12import uuid
13
14from .nlspec import SpecFamily
15
16#
17# Generic Netlink code which should really be in some library, but I can't quickly find one.
18#
19
20
class Netlink:
    """Numeric constants for the netlink / generic netlink wire protocol.

    Only the subset used by this module is defined.  Several NLM_F_* names
    intentionally share a bit value: each group below is used in a different
    context (dump requests, object-creation requests, received ACKs).
    """

    # --- Socket level and socket options -----------------------------------
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # --- nlmsg_type values for control messages ----------------------------
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # --- nlmsg_flags: generic request bits ---------------------------------
    NLM_F_REQUEST = 1 << 0
    NLM_F_ACK = 1 << 2

    # --- nlmsg_flags: GET (dump) modifiers ---------------------------------
    NLM_F_ROOT = 1 << 8
    NLM_F_MATCH = 1 << 9

    # --- nlmsg_flags: NEW (create/modify) modifiers ------------------------
    NLM_F_REPLACE = 1 << 8
    NLM_F_EXCL = 1 << 9
    NLM_F_CREATE = 1 << 10
    NLM_F_APPEND = 1 << 11

    # --- nlmsg_flags: set by the kernel on ACK / error messages ------------
    NLM_F_CAPPED = 1 << 8
    NLM_F_ACK_TLVS = 1 << 9

    # A dump request is ROOT + MATCH.
    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # --- nla_type flag bits -------------------------------------------------
    NLA_F_NESTED = 1 << 15
    NLA_F_NET_BYTEORDER = 1 << 14

    # The flag bits above; NlAttr clears these to get the plain attribute type.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # --- Generic netlink ----------------------------------------------------
    NETLINK_GENERIC = 16

    # Family id of the generic netlink controller (nlctrl).
    GENL_ID_CTRL = 16

    # nlctrl command and attributes
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # --- Extended ACK attribute types ---------------------------------------
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6
77
78
class NlError(Exception):
    """Exception wrapping a netlink reply that carries an error code.

    ``nl_msg`` is the reply message; its ``error`` field holds a negative
    errno, which is rendered with os.strerror() in the string form.
    """

    def __init__(self, nl_msg):
        # Pass the message to Exception so ``args`` is populated (useful for
        # repr(), pickling and generic exception handlers).
        super().__init__(nl_msg)
        self.nl_msg = nl_msg

    def __str__(self):
        return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}"
85
86
class NlAttr:
    """One netlink attribute (4-byte nlattr header + payload) read from a buffer.

    Attributes:
        type: attribute type with the NLA_F_NESTED / NLA_F_NET_BYTEORDER
            flag bits cleared.
        payload_len: the header's length field (header + payload bytes).
        full_len: payload_len rounded up to a multiple of 4, i.e. the
            distance from this attribute to the next one.
        raw: payload bytes without the header.
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
        # The length field includes the 4-byte header, so anything shorter is
        # malformed.  A zero length would also make every loop that advances
        # by full_len spin forever.
        if self._len < 4:
            raise Exception(f"Invalid netlink attribute length {self._len} at offset {offset}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4:offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the struct.Struct used to (un)pack scalar type *attr_type*.

        A falsy *byte_order* selects host-native order, "big-endian" selects
        big-endian and any other value selects little-endian.
        """
        formats = cls.type_formats[attr_type]
        if byte_order:
            return formats.big if byte_order == "big-endian" \
                else formats.little
        return formats.native

    @classmethod
    def formatted_string(cls, raw, display_hint):
        """Render *raw* per *display_hint* ('mac', 'hex', 'ipv4', 'ipv6',
        'uuid'); unknown hints return *raw* unchanged."""
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the whole payload as one scalar of type *attr_type*."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable-width ('uint'/'sint') scalar.

        The width (32 or 64 bits) is taken from the payload length; the first
        character of *attr_type* ('u' or 's') selects signedness.
        """
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode an ASCII string, dropping its last byte (the NUL terminator)."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of native-order *type* scalars."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def as_struct(self, members):
        """Decode the payload as a C struct described by spec *members*.

        Supports 'binary' members (using the member's 'len') and the scalar
        types in type_formats.  Returns a dict keyed by member name.

        Raises:
            Exception: for a member type that cannot be decoded.
        """
        value = dict()
        offset = 0
        for m in members:
            # TODO: handle non-scalar members
            if m.type == 'binary':
                decoded = self.raw[offset:offset+m['len']]
                offset += m['len']
            elif m.type in NlAttr.type_formats:
                fmt = self.get_format(m.type, m.byte_order)
                [ decoded ] = fmt.unpack_from(self.raw, offset)
                offset += fmt.size
            else:
                # Without this, 'decoded' would be unbound or would silently
                # carry over the previous member's value.
                raise Exception(f"Unsupported struct member type '{m.type}' for member '{m.name}'")
            if m.display_hint:
                decoded = self.formatted_string(decoded, m.display_hint)
            value[m.name] = decoded
        return value

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
169
170
class NlAttrs:
    """The sequence of netlink attributes packed back-to-back in *msg*.

    Each attribute is parsed with NlAttr and the cursor advances by its
    4-byte-aligned length until the buffer is exhausted.
    """

    def __init__(self, msg):
        self.attrs = []
        pos = 0
        end = len(msg)
        while pos < end:
            attr = NlAttr(msg, pos)
            self.attrs.append(attr)
            pos += attr.full_len

    def __iter__(self):
        return iter(self.attrs)

    def __repr__(self):
        # One attribute per line, no trailing newline.
        return '\n'.join(repr(attr) for attr in self.attrs)
191
192
class NlMsg:
    """One netlink message (16-byte nlmsghdr + payload) taken from *msg* at *offset*.

    ``raw`` is the payload after the header.  For NLMSG_ERROR and NLMSG_DONE
    messages ``done`` is set to 1 (and ``error`` to the reported code for
    NLMSG_ERROR).  When the header carries NLM_F_ACK_TLVS, the extended-ACK
    attributes are decoded into the ``extack`` dict.  *attr_space*, when
    given, is used to replace a numeric 'miss-type' with the spec's
    attribute name (and doc, if any).
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset:offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16:offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # The payload is a 4-byte error code followed by the original
            # request's 16-byte nlmsghdr.  With NLM_F_CAPPED that is all;
            # otherwise the request's payload is echoed too (padded to 4
            # bytes), so the TLVs start after the request's full length.
            extack_off = 20
            if not self.nl_flags & Netlink.NLM_F_CAPPED and len(self.raw) >= 8:
                orig_len = struct.unpack_from("I", self.raw, 4)[0]
                extack_off = max(20, 4 + ((orig_len + 3) & ~3))
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        desc = spec['name']
                        if 'doc' in spec:
                            desc += f" ({spec['doc']})"
                        self.extack['miss-type'] = desc

    def cmd(self):
        """Return the command carried by this message (its nlmsg_type)."""
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
        if self.error:
            msg += '\terror: ' + str(self.error)
        if self.extack:
            msg += '\textack: ' + repr(self.extack)
        return msg
253
254
class NlMsgs:
    """Split a buffer read from a netlink socket into NlMsg objects.

    *attr_space* is forwarded to every NlMsg for extack decoding.

    Raises:
        Exception: if a message header reports a length shorter than the
            16-byte nlmsghdr (which would otherwise loop forever on 0).
    """

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            nl_len = struct.unpack_from("I", data, offset)[0]
            if nl_len < 16:
                raise Exception(f"Invalid netlink message length {nl_len} at offset {offset}")
            msg = NlMsg(data, offset, attr_space=attr_space)
            # Consecutive messages start on 4-byte boundaries (NLMSG_ALIGN).
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
267
268
# Cache of generic netlink families, filled by _genl_load_families() from a
# CTRL_CMD_GETFAMILY dump and keyed by family name.  Each value is a dict with
# 'id' and 'name', plus 'maxattr' and 'mcast' (group name -> group id) when
# the kernel reported them.  Stays None until GenlProtocol first needs it.
genl_family_name_to_id = None
270
271
272def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
273    # we prepend length in _genl_msg_finalize()
274    if seq is None:
275        seq = random.randint(1, 1024)
276    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
277    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
278    return nlmsg + genlmsg
279
280
281def _genl_msg_finalize(msg):
282    return struct.pack("I", len(msg) + 4) + msg
283
284
def _genl_load_families():
    """Fill genl_family_name_to_id from the kernel's generic netlink controller.

    Sends a CTRL_CMD_GETFAMILY dump request and records every family whose
    reply contains both a name and an id.  If the kernel answers with an
    error, the error code is printed and loading stops, keeping whatever was
    collected so far.
    """
    global genl_family_name_to_id

    def parse_mcast_groups(groups_attr):
        # Each nested entry carries one group's name and numeric id.
        groups = {}
        for entry in NlAttrs(groups_attr.raw):
            grp_name = grp_id = None
            for field in NlAttrs(entry.raw):
                if field.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                    grp_name = field.as_strz()
                elif field.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                    grp_id = field.as_scalar('u32')
            if grp_name and grp_id is not None:
                groups[grp_name] = grp_id
        return groups

    def parse_family(nl_msg):
        # Collect the controller attributes this module cares about.
        family = {}
        for attr in NlAttrs(GenlMsg(nl_msg).raw):
            if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                family['id'] = attr.as_scalar('u16')
            elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                family['name'] = attr.as_strz()
            elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                family['maxattr'] = attr.as_scalar('u32')
            elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                family['mcast'] = parse_mcast_groups(attr)
        return family

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        dump_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP
        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL, dump_flags, Netlink.CTRL_CMD_GETFAMILY, 1))
        sock.send(request, 0)

        genl_family_name_to_id = {}
        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return
                family = parse_family(nl_msg)
                if 'name' in family and 'id' in family:
                    genl_family_name_to_id[family['name']] = family
332
333
class GenlMsg:
    """Generic netlink view of an NlMsg.

    Splits the 4-byte genlmsghdr (cmd, version, reserved) off the front of
    the netlink payload; ``raw`` holds what follows it.
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the generic netlink command."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is attached only by NetlinkProtocol.decode(); a GenlMsg
        # built directly (e.g. while loading families) has none, and repr()
        # must not fail on it.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg
349
350
class NetlinkProtocol:
    """Message framing for a classic (non-generic) netlink protocol.

    The command travels in nlmsg_type and there is no genlmsghdr.
    """

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack an nlmsghdr without its length word (port id 0).

        A random sequence number in [1, 1024] is used when *seq* is None.
        """
        if seq is None:
            seq = random.randint(1, 1024)
        return struct.pack("HHII", nl_type, nl_flags, seq, 0)

    def message(self, flags, command, version, seq=None):
        # *version* has no place in a classic netlink header and is ignored.
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        """Protocol-specific view of *nl_msg*; identity for classic netlink."""
        return nl_msg

    def decode(self, ynl, nl_msg):
        """Wrap *nl_msg* and attach its attributes as ``raw_attrs``.

        When *ynl* is given, the fixed header of the response op matching
        the message's command is skipped before the attributes are parsed.
        """
        msg = self._decode(nl_msg)
        skip = 0
        if ynl:
            skip = ynl._fixed_header_size(ynl.rsp_by_value[msg.cmd()])
        msg.raw_attrs = NlAttrs(msg.raw[skip:])
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the numeric id of spec multicast group *mcast_name*."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
381
382
class GenlProtocol(NetlinkProtocol):
    """Generic netlink framing: the kernel-assigned family id goes in
    nlmsg_type and a genlmsghdr carries the command and version."""

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        # Query the controller once per process; later instances reuse it.
        if genl_family_name_to_id is None:
            _genl_load_families()

        # KeyError here means the kernel does not know this family.
        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']

    def message(self, flags, command, version, seq=None):
        nl_hdr = self._message(self.family_id, flags, seq)
        return nl_hdr + struct.pack("BBH", command, version, 0)

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the kernel-assigned id of multicast group *mcast_name*.

        *mcast_groups* (the spec's view) is unused: ids come from the family
        data reported by the kernel.
        """
        groups = self.genl_family['mcast']
        if mcast_name in groups:
            return groups[mcast_name]
        raise Exception(f'Multicast group "{mcast_name}" not present in the family')
406
407
408#
409# YNL implementation details.
410#
411
412
class YnlFamily(SpecFamily):
    """Netlink client for one protocol family, driven by a YAML spec.

    Opens a netlink socket for the family described by *def_path* and, for
    every operation in the spec, installs a method named after the
    operation's ``ident_name`` that calls :meth:`_op` with that operation.
    """

    def __init__(self, def_path, schema=None):
        super().__init__(def_path, schema)

        # When True, handle_ntf() also stores the decoded message object
        # under the 'raw' key of each queued notification.
        self.include_raw = False

        try:
            if self.proto == "netlink-raw":
                self.nlproto = NetlinkProtocol(self.yaml['name'],
                                               self.yaml['protonum'])
            else:
                self.nlproto = GenlProtocol(self.yaml['name'])
        except KeyError:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

        # Response values of asynchronous (notification) messages, and the
        # decoded notifications received so far.
        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)


    def ntf_subscribe(self, mcast_name):
        """Join multicast group *mcast_name* to receive its notifications."""
        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
        # Binding an already-bound netlink socket to port id 0 fails with
        # EINVAL (the socket is autobound by its first send, or by a previous
        # subscription), so only bind while no port id is assigned.
        if self.sock.getsockname()[0] == 0:
            self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_id)

    def _add_attr(self, space, name, value):
        """Encode attribute *name* of attribute set *space* with *value*.

        Returns the complete TLV (header, payload, padding to 4 bytes).

        Raises:
            Exception: for an unknown attribute name or unsupported type.
        """
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value
        attr_type = attr["type"]
        if attr_type == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''.join(
                self._add_attr(attr['nested-attributes'], subname, subvalue)
                for subname, subvalue in value.items())
        elif attr_type == 'flag':
            attr_payload = b''
        elif attr_type == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr_type == 'binary':
            if isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                attr_payload = bytes.fromhex(value)
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr.is_auto_scalar:
            scalar = int(value)
            if attr_type[0] == 's':
                # bit_length() alone would send 2**31..2**32-1 to the signed
                # 32-bit format, which cannot hold them.
                fits_32 = -(1 << 31) <= scalar < (1 << 31)
            else:
                fits_32 = scalar.bit_length() <= 32
            real_type = attr_type[0] + ('32' if fits_32 else '64')
            fmt = NlAttr.get_format(real_type, attr.byte_order)
            attr_payload = fmt.pack(scalar)
        elif attr_type in NlAttr.type_formats:
            fmt = NlAttr.get_format(attr_type, attr.byte_order)
            attr_payload = fmt.pack(int(value))
        elif attr_type == 'bitfield32':
            attr_payload = struct.pack("II", int(value["value"]), int(value["selector"]))
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Translate integer *raw* into name(s) from enum attr_spec['enum'].

        For a flags enum, or when the spec sets 'enum-as-flags', every set
        bit i maps to entries_by_val[i] and a set of names is returned;
        otherwise the name of the entry with value *raw* is returned.
        """
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attribute as a struct (dict with enum members
        translated), a C array of sub_type scalars, or bytes (optionally
        rendered via the spec's display hint)."""
        if attr_spec.struct_name:
            members = self.consts[attr_spec.struct_name]
            decoded = attr.as_struct(members)
            for m in members:
                if m.enum:
                    decoded[m.name] = self._decode_enum(decoded[m.name], m)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = NlAttr.formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode_array_nest(self, attr, attr_spec):
        """Decode an array-nest into a list of {item type: decoded nest}."""
        decoded = []
        offset = 0
        while offset < len(attr.raw):
            item = NlAttr(attr.raw, offset)
            offset += item.full_len

            subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
            decoded.append({ item.type: subattrs })
        return decoded

    def _decode(self, attrs, space):
        """Decode *attrs* using attribute set *space* into a dict keyed by
        attribute name; multi-attributes are collected into lists."""
        attr_space = self.attr_sets[space]
        rsp = dict()
        for attr in attrs:
            try:
                attr_spec = attr_space.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
            if attr_spec["type"] == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
                decoded = subdict
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                decoded = True
            elif attr_spec.is_auto_scalar:
                decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
            elif attr_spec["type"] in NlAttr.type_formats:
                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
                if 'enum' in attr_spec:
                    decoded = self._decode_enum(decoded, attr_spec)
            elif attr_spec["type"] == 'array-nest':
                decoded = self._decode_array_nest(attr, attr_spec)
            elif attr_spec["type"] == 'bitfield32':
                value, selector = struct.unpack("II", attr.raw)
                if 'enum' in attr_spec:
                    value = self._decode_enum(value, attr_spec)
                    selector = self._decode_enum(selector, attr_spec)
                decoded = {"value": value, "selector": selector}
            else:
                raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')

            if not attr_spec.is_multi:
                rsp[attr_spec['name']] = decoded
            elif attr_spec.name in rsp:
                rsp[attr_spec.name].append(decoded)
            else:
                rsp[attr_spec.name] = [decoded]

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return the dotted attribute path ('.a.b') of the attribute that
        starts at byte *target*, given that *attrs* start at *offset*;
        None when no attribute starts there."""
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, op, extack):
        """Replace extack['bad-attr-offs'] with a readable 'bad-attr' path.

        *request* is the raw request message sent for operation *op*.
        """
        if 'bad-attr-offs' not in extack:
            return

        nl_msg = NlMsg(request, 0, op.attr_set)
        # Use the request op's own fixed header rather than looking the
        # request command up among response values.
        msg = self.nlproto._decode(nl_msg)
        fixed_header_size = self._fixed_header_size(op)
        attrs = NlAttrs(msg.raw[fixed_header_size:])
        # The reported offset counts from the start of the request's
        # nlmsghdr.  Everything in front of msg.raw (nlmsghdr, plus the
        # genlmsghdr for generic netlink) and the fixed header precede the
        # first attribute.
        offset = nl_msg.nl_len - len(msg.raw) + fixed_header_size
        path = self._decode_extack_path(attrs, op.attr_set, offset,
                                        extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def _fixed_header_size(self, op):
        """Size in bytes of *op*'s fixed header (0 when it has none)."""
        if op.fixed_header:
            fixed_header_members = self.consts[op.fixed_header].members
            size = 0
            for m in fixed_header_members:
                fmt = NlAttr.get_format(m.type, m.byte_order)
                size += fmt.size
            return size
        else:
            return 0

    def _decode_fixed_header(self, msg, name):
        """Unpack fixed-header struct *name* from the start of msg.raw."""
        fixed_header_members = self.consts[name].members
        fixed_header_attrs = dict()
        offset = 0
        for m in fixed_header_members:
            fmt = NlAttr.get_format(m.type, m.byte_order)
            [ value ] = fmt.unpack_from(msg.raw, offset)
            offset += fmt.size
            if m.enum:
                value = self._decode_enum(value, m)
            fixed_header_attrs[m.name] = value
        return fixed_header_attrs

    def handle_ntf(self, decoded):
        """Decode notification *decoded* and append it to async_msg_queue."""
        msg = dict()
        if self.include_raw:
            msg['raw'] = decoded
        op = self.rsp_by_value[decoded.cmd()]
        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
        if op.fixed_header:
            attrs.update(self._decode_fixed_header(decoded, op.fixed_header))

        msg['name'] = op['name']
        msg['msg'] = attrs
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Read pending messages without blocking and queue notifications."""
        while True:
            try:
                reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                decoded = self.nlproto.decode(self, nl_msg)
                if decoded.cmd() not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", decoded)
                    continue

                self.handle_ntf(decoded)

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _op(self, method, vals, flags, dump=False):
        """Send operation *method* built from *vals* and collect the replies.

        Fixed-header members are taken from *vals* by name (0 when absent);
        all other entries are encoded as attributes.  *flags* holds extra
        NLM_F_* bits.  Returns None when no data reply arrived, a single
        dict for a non-dump request with exactly one reply, and a list of
        dicts otherwise.

        Raises:
            NlError: when the kernel replies with an error.
        """
        op = self.ops[method]

        # Work on a copy so consuming fixed-header members below does not
        # mutate the caller's dict (re-using it would lose those values).
        vals = dict(vals) if vals else {}

        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag
        if dump:
            nl_flags |= Netlink.NLM_F_DUMP

        req_seq = random.randint(1024, 65535)
        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            for m in self.consts[op.fixed_header].members:
                value = vals.pop(m.name) if m.name in vals else 0
                fmt = NlAttr.get_format(m.type, m.byte_order)
                msg += fmt.pack(value)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value)
        msg = _genl_msg_finalize(msg)

        self.sock.send(msg, 0)

        done = False
        rsp = []
        while not done:
            reply = self.sock.recv(128 * 1024)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            for nl_msg in nms:
                if nl_msg.extack:
                    self._decode_extack(msg, op, nl_msg.extack)

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)
                    done = True
                    break

                decoded = self.nlproto.decode(self, nl_msg)

                # Check if this is a reply to our request
                if nl_msg.nl_seq != req_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_fixed_header(decoded, op.fixed_header))
                rsp.append(rsp_msg)

        if not rsp:
            return None
        if not dump and len(rsp) == 1:
            return rsp[0]
        return rsp

    def do(self, method, vals, flags):
        """Run operation *method* as a single (non-dump) request."""
        return self._op(method, vals, flags)

    def dump(self, method, vals):
        """Run operation *method* as a dump request."""
        return self._op(method, vals, [], dump=True)
751