xref: /linux/tools/net/ynl/lib/ynl.py (revision 3d0fe49454652117522f60bfbefb978ba0e5300b)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4import functools
5import os
6import random
7import socket
8import struct
9from struct import Struct
10import yaml
11import ipaddress
12import uuid
13
14from .nlspec import SpecFamily
15
16#
17# Generic Netlink code which should really be in some library, but I can't quickly find one.
18#
19
20
class Netlink:
    """Numeric netlink, generic-netlink and nlctrl constants used by this module.

    Grouped by use: socket options, message types and header flags,
    attribute header flags, genetlink/nlctrl ids, and extended-ack
    attribute types.
    """

    # --- Socket level and options (setsockopt on SOL_NETLINK) ---
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # --- nlmsg_type values of control messages ---
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # --- nlmsg_flags ---
    NLM_F_REQUEST = 1 << 0
    NLM_F_ACK = 1 << 2

    # Dump modifiers (combined into NLM_F_DUMP below).
    NLM_F_ROOT = 1 << 8
    NLM_F_MATCH = 1 << 9

    # Creation modifiers. These deliberately reuse the same bit values as
    # the dump modifiers above and the ack flags below.
    NLM_F_REPLACE = 1 << 8
    NLM_F_EXCL = 1 << 9
    NLM_F_CREATE = 1 << 10
    NLM_F_APPEND = 1 << 11

    # Flags checked on error/done messages (see NlMsg).
    NLM_F_CAPPED = 1 << 8
    NLM_F_ACK_TLVS = 1 << 9

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # --- nla_type flag bits ---
    NLA_F_NESTED = 1 << 15
    NLA_F_NET_BYTEORDER = 1 << 14

    # Note: despite the name, this is the union of the *flag* bits; callers
    # clear them with `& ~NLA_TYPE_MASK` to get the bare attribute type.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # --- Generic netlink ---
    NETLINK_GENERIC = 16        # protocol number passed to socket()

    GENL_ID_CTRL = 0x10         # nlmsg_type used to talk to nlctrl

    # nlctrl command used to enumerate families.
    CTRL_CMD_GETFAMILY = 3

    # Top-level nlctrl family attributes.
    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    # Attributes inside each CTRL_ATTR_MCAST_GROUPS entry.
    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # --- Extended ACK attribute types ---
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6
77
78
class NlError(Exception):
    """Raised when the kernel answers a request with a netlink error.

    ``nl_msg`` is the error message; its ``error`` field holds the negative
    errno value.
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg

    def __str__(self):
        errno_text = os.strerror(-self.nl_msg.error)
        return "Netlink error: " + errno_text + "\n" + str(self.nl_msg)
85
86
class NlAttr:
    """A single netlink attribute (TLV) parsed out of a raw buffer.

    Attributes set by the constructor:
      type      -- attribute type with NLA_F_NESTED / NLA_F_NET_BYTEORDER cleared
      is_nest   -- non-zero when NLA_F_NESTED was set
      raw       -- payload bytes, without the 4-byte header
      full_len  -- header-inclusive length rounded up to 4 bytes, i.e. the
                   distance to the next attribute
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    # Packers for each fixed-size scalar type in host, big- and little-endian order.
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        # nla_len includes the 4-byte header, so it also marks the payload end.
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4:offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct used to (un)pack scalar ``attr_type``.

        No ``byte_order`` selects host order.
        "big-endian" selects big-endian.
        Any other non-empty value selects little-endian.
        Raises KeyError for non-scalar types.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" else fmt.little
        return fmt.native

    @classmethod
    def formatted_string(cls, raw, display_hint):
        """Render ``raw`` according to a spec ``display_hint``.

        'mac' gives colon-separated hex.
        'hex' gives space-separated hex.
        'ipv4' / 'ipv6' give the address string.
        'uuid' gives the UUID string.
        Any other hint returns ``raw`` unchanged.
        """
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the payload as fixed-size scalar ``attr_type``."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a variable-width scalar whose width is given by the payload size.

        The first letter of ``attr_type`` ('u' or 's') picks signedness.
        Raises Exception unless the payload is 4 or 8 bytes long.
        """
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII, dropping the last byte (assumed NUL terminator)."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes unchanged."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as a packed array of scalar ``type`` (host order)."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def as_struct(self, members):
        """Decode the payload as a packed struct described by ``members``.

        Members of type 'binary' take ``m['len']`` raw bytes; scalar members
        are unpacked with their byte order; ``display_hint`` is applied last.
        Returns {member name: value}.  Raises Exception for any other member
        type.
        """
        value = dict()
        offset = 0
        for m in members:
            # TODO: handle non-scalar members
            if m.type == 'binary':
                decoded = self.raw[offset:offset + m['len']]
                offset += m['len']
            elif m.type in NlAttr.type_formats:
                fmt = self.get_format(m.type, m.byte_order)
                [ decoded ] = fmt.unpack_from(self.raw, offset)
                offset += fmt.size
            else:
                # Without this, 'decoded' would be unbound (first member) or
                # silently carry over the previous member's value.
                raise Exception(f"Unsupported struct member type '{m.type}' for member '{m.name}'")
            if m.display_hint:
                decoded = self.formatted_string(decoded, m.display_hint)
            value[m.name] = decoded
        return value

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
170
171
class NlAttrs:
    """The sequence of NlAttr objects laid out back-to-back in ``msg``."""

    def __init__(self, msg):
        self.attrs = []

        offset = 0
        while offset < len(msg):
            attr = NlAttr(msg, offset)
            if attr._len < 4:
                # A length shorter than the 4-byte header is malformed; a
                # length of 0 would also never advance offset (endless loop).
                raise Exception(f"Malformed netlink attribute: length {attr._len} at offset {offset}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        # One attribute per line.
        return '\n'.join(repr(a) for a in self.attrs)
192
193
class NlMsg:
    """One netlink message parsed from ``msg`` at ``offset``.

    Exposes the nlmsghdr fields (nl_len, nl_type, nl_flags, nl_seq,
    nl_portid), the payload ``raw``, ``error`` (errno from NLMSG_ERROR,
    0 otherwise), ``done`` (1 for NLMSG_ERROR / NLMSG_DONE), and ``extack``
    (dict of decoded extended-ack attributes, or None).

    ``attr_space``, if given, is used to turn a missing-attribute type
    into its spec name (only for non-nested misses).
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset:offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16:offset + self.nl_len]

        self.error = 0
        self.done = 0

        # Offset of the extack TLVs inside the payload.
        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # 4-byte error code plus the 16-byte echoed request header.
            # Assumes a capped ack (NETLINK_CAP_ACK), as set on the sockets in
            # this file, so no request payload is echoed.
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        desc = spec['name']
                        if 'doc' in spec:
                            desc += f" ({spec['doc']})"
                        self.extack['miss-type'] = desc

    def cmd(self):
        """Return the message type (nlmsg_type)."""
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
        # Terminate each optional line so error and extack do not run together
        # and callers can keep appending lines after this text.
        if self.error:
            msg += f"\terror: {self.error}\n"
        if self.extack:
            msg += f"\textack: {self.extack!r}\n"
        return msg
254
255
class NlMsgs:
    """All netlink messages contained in one received buffer ``data``."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            if msg.nl_len < 16:
                # Shorter than nlmsghdr: malformed, and nl_len == 0 would
                # never advance offset.
                raise Exception(f"Malformed netlink message: length {msg.nl_len} at offset {offset}")
            # Successive messages start on 4-byte (NLMSG_ALIGN) boundaries.
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
268
269
# Cache of generic netlink families, keyed by family name. Despite the name,
# each value is the whole family dict ('id', 'name', and when reported
# 'maxattr' and 'mcast'), not just the id. None means "not loaded yet";
# _genl_load_families() fills it on first use.
genl_family_name_to_id = None
271
272
273def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
274    # we prepend length in _genl_msg_finalize()
275    if seq is None:
276        seq = random.randint(1, 1024)
277    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
278    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
279    return nlmsg + genlmsg
280
281
282def _genl_msg_finalize(msg):
283    return struct.pack("I", len(msg) + 4) + msg
284
285
def _genl_parse_mcast_groups(raw):
    """Parse a CTRL_ATTR_MCAST_GROUPS payload into {group name: group id}.

    Entries without a non-empty name or without an id are skipped.
    """
    groups = dict()
    for entry in NlAttrs(raw):
        grp_name, grp_id = None, None
        for entry_attr in NlAttrs(entry.raw):
            if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                grp_name = entry_attr.as_strz()
            elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                grp_id = entry_attr.as_scalar('u32')
        if grp_name and grp_id is not None:
            groups[grp_name] = grp_id
    return groups


def _genl_load_families():
    """Dump every generic netlink family via nlctrl into genl_family_name_to_id.

    Only families reporting both a name and an id are stored. On a netlink
    error the error code is printed and whatever was collected so far is kept.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        dump_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP
        request = _genl_msg(Netlink.GENL_ID_CTRL, dump_flags,
                            Netlink.CTRL_CMD_GETFAMILY, 1)
        sock.send(_genl_msg_finalize(request), 0)

        genl_family_name_to_id = dict()

        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                # An error or the end-of-dump marker finishes the load.
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                fam = dict()
                for attr in NlAttrs(GenlMsg(nl_msg).raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = _genl_parse_mcast_groups(attr.raw)
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam
333
334
class GenlMsg:
    """Generic netlink view of an NlMsg: splits off the 4-byte genlmsghdr.

    ``raw`` is the payload after that header. ``raw_attrs`` is not set here;
    NetlinkProtocol.decode() attaches it, and __repr__ needs it.
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        header = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.genl_cmd = header[0]
        self.genl_version = header[1]
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the generic netlink command number."""
        return self.genl_cmd

    def __repr__(self):
        pieces = [
            repr(self.nl),
            f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n",
        ]
        pieces.extend(f"\t\t{a!r}\n" for a in self.raw_attrs)
        return ''.join(pieces)
350
351
class NetlinkProtocol:
    """Message framing for a classic (non-generic) netlink protocol family."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack an nlmsghdr without its leading nlmsg_len field.

        A random sequence number in [1, 1024] is used when ``seq`` is None.
        """
        sequence = random.randint(1, 1024) if seq is None else seq
        return struct.pack("HHII", nl_type, nl_flags, sequence, 0)

    def message(self, flags, command, version, seq=None):
        """Start a request: ``command`` becomes nlmsg_type; ``version`` is unused here."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        # No protocol-specific header to strip.
        return nl_msg

    def decode(self, ynl, nl_msg):
        """Wrap ``nl_msg`` and attach ``raw_attrs``, skipping any fixed header.

        When ``ynl`` is given, the op is looked up by the message's reply
        value to learn the fixed-header size; otherwise none is assumed.
        """
        msg = self._decode(nl_msg)
        if ynl:
            op = ynl.rsp_by_value[msg.cmd()]
            skip = ynl._fixed_header_size(op)
        else:
            skip = 0
        msg.raw_attrs = NlAttrs(msg.raw[skip:])
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the id of ``mcast_name`` as declared in the spec's groups."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
382
383
class GenlProtocol(NetlinkProtocol):
    """Message framing for a generic netlink family, resolved by name.

    The family table is loaded from the kernel on first use. Raises KeyError
    if ``family_name`` is not registered.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        # Only read here; _genl_load_families() populates the module cache.
        if genl_family_name_to_id is None:
            _genl_load_families()

        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']

    def message(self, flags, command, version, seq=None):
        """Start a request: nlmsghdr addressed to this family plus a genlmsghdr."""
        return self._message(self.family_id, flags, seq) + \
            struct.pack("BBH", command, version, 0)

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the kernel-assigned id of ``mcast_name`` (``mcast_groups`` is ignored)."""
        kernel_groups = self.genl_family['mcast']
        if mcast_name not in kernel_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return kernel_groups[mcast_name]
407
408
409#
410# YNL implementation details.
411#
412
413
class YnlFamily(SpecFamily):
    """A netlink family driven by a YNL YAML spec.

    The constructor opens a netlink socket for the family's protocol:
    genetlink, or the spec's 'protonum' for "netlink-raw" specs.
    It also binds each operation in the spec as a method named
    ``op.ident_name``, with signature ``(vals, flags, dump=False)``.
    It raises Exception if the family cannot be resolved.
    """

    def __init__(self, def_path, schema=None, process_unknown=False):
        super().__init__(def_path, schema)

        # When set, handle_ntf() also stores the decoded message under 'raw'.
        self.include_raw = False
        # When set, attributes the spec does not describe are decoded
        # generically (_decode_unknown) instead of raising.
        self.process_unknown = process_unknown

        try:
            if self.proto == "netlink-raw":
                self.nlproto = NetlinkProtocol(self.yaml['name'],
                                               self.yaml['protonum'])
            else:
                self.nlproto = GenlProtocol(self.yaml['name'])
        except KeyError:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

        # Reply values of messages the spec marks async (notifications), and
        # notifications decoded so far but not yet consumed.
        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        # Expose each operation as a method named op.ident_name.
        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

    def ntf_subscribe(self, mcast_name):
        """Join multicast group ``mcast_name`` to receive its notifications."""
        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_id)

    def _add_attr(self, space, name, value):
        """Encode attribute ``name`` of set ``space`` (recursing into nests).

        Returns the padded TLV bytes, or b'' for a flag whose value is false.
        Raises Exception for an unknown attribute or a type it cannot encode.
        """
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value
        attr_type = attr["type"]
        if attr_type == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''.join(
                self._add_attr(attr['nested-attributes'], subname, subvalue)
                for subname, subvalue in value.items())
        elif attr_type == 'flag':
            if not value:
                # A flag is expressed by its presence; false means omit it.
                return b''
            attr_payload = b''
        elif attr_type == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr_type == 'binary':
            if isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                attr_payload = bytes.fromhex(value)
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr.is_auto_scalar:
            scalar = int(value)
            if attr_type[0] == 's':
                # Signed: only [-2^31, 2^31) fits in s32.
                fits_32 = -(1 << 31) <= scalar < (1 << 31)
            else:
                fits_32 = scalar.bit_length() <= 32
            real_type = attr_type[0] + ('32' if fits_32 else '64')
            fmt = NlAttr.get_format(real_type, attr.byte_order)
            attr_payload = fmt.pack(scalar)
        elif attr_type in NlAttr.type_formats:
            fmt = NlAttr.get_format(attr_type, attr.byte_order)
            attr_payload = fmt.pack(int(value))
        elif attr_type == "bitfield32":
            attr_payload = struct.pack("II", int(value["value"]), int(value["selector"]))
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        # Pad the payload to 4 bytes; nla_len covers header + unpadded payload.
        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Map integer ``raw`` to enum entry name(s) from attr_spec['enum'].

        Flag-style enums (type 'flags', or 'enum-as-flags' set) return the set
        of entry names for every bit set in ``raw``; others return one name.
        """
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        """Decode a 'binary' attribute as a struct, C array, or bytes.

        The result depends on the spec: struct / sub-type / display-hint.
        """
        if attr_spec.struct_name:
            members = self.consts[attr_spec.struct_name]
            decoded = attr.as_struct(members)
            for m in members:
                if m.enum:
                    decoded[m.name] = self._decode_enum(decoded[m.name], m)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = NlAttr.formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode_array_nest(self, attr, attr_spec):
        """Decode an array-nest into a list of {item type: decoded nest} dicts."""
        return [{item.type: self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])}
                for item in NlAttrs(attr.raw)]

    def _decode_unknown(self, attr):
        """Decode an attribute absent from the spec: a dict for nests, else bytes."""
        if attr.is_nest:
            return self._decode(NlAttrs(attr.raw), None)
        else:
            return attr.as_bin()

    def _rsp_add(self, rsp, name, is_multi, decoded):
        """Store ``decoded`` under ``name`` in ``rsp``.

        is_multi True appends to a list, False overwrites.  None (unknown
        attributes) keeps a single value until the name repeats, then
        collects every occurrence into a list.
        """
        if is_multi is None:
            if name in rsp:
                # Repeated name: make sure the stored value is a list and
                # append every further occurrence (never overwrite it).
                if not isinstance(rsp[name], list):
                    rsp[name] = [rsp[name]]
                is_multi = True
            else:
                is_multi = False

        if not is_multi:
            rsp[name] = decoded
        elif name in rsp:
            rsp[name].append(decoded)
        else:
            rsp[name] = [decoded]

    def _decode(self, attrs, space):
        """Decode ``attrs`` against attribute set ``space`` into a dict.

        With ``space`` None every attribute is treated as unknown.
        Unknown attributes raise Exception unless process_unknown is set.
        When it is set, they are stored as "UnknownAttr(<type>)".
        """
        attr_space = self.attr_sets[space] if space else None
        rsp = dict()
        for attr in attrs:
            if attr_space is None or attr.type not in attr_space.attrs_by_val:
                if not self.process_unknown:
                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
                attr_name = f"UnknownAttr({attr.type})"
                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
                continue
            attr_spec = attr_space.attrs_by_val[attr.type]

            attr_type = attr_spec["type"]
            if attr_type == 'nest':
                decoded = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
            elif attr_type == 'string':
                decoded = attr.as_strz()
            elif attr_type == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_type == 'flag':
                decoded = True
            elif attr_spec.is_auto_scalar:
                decoded = attr.as_auto_scalar(attr_type, attr_spec.byte_order)
            elif attr_type in NlAttr.type_formats:
                decoded = attr.as_scalar(attr_type, attr_spec.byte_order)
                if 'enum' in attr_spec:
                    decoded = self._decode_enum(decoded, attr_spec)
            elif attr_type == 'array-nest':
                decoded = self._decode_array_nest(attr, attr_spec)
            elif attr_type == 'bitfield32':
                value, selector = struct.unpack("II", attr.raw)
                if 'enum' in attr_spec:
                    value = self._decode_enum(value, attr_spec)
                    selector = self._decode_enum(selector, attr_spec)
                decoded = {"value": value, "selector": selector}
            else:
                if not self.process_unknown:
                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
                decoded = self._decode_unknown(attr)

            self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return the dotted attribute path at byte ``target``, or None.

        ``offset`` is the byte position of the first attribute in ``attrs``,
        measured in the same frame as ``target``.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Step over this nest's 4-byte header into its first child.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, op, extack):
        """Replace extack 'bad-attr-offs' with a 'bad-attr' path, when resolvable."""
        if 'bad-attr-offs' not in extack:
            return

        # Decode the request against ``op`` directly. nlproto.decode() would
        # look the op up by *reply* value (rsp_by_value), which can differ
        # from the request value and pick the wrong op or raise KeyError.
        nl_msg = NlMsg(request, 0, op.attr_set)
        msg = self.nlproto._decode(nl_msg)
        fixed_header_size = self._fixed_header_size(op)
        raw_attrs = NlAttrs(msg.raw[fixed_header_size:])

        # The kernel's offset counts from the start of the request.
        # The protocol headers are what _decode() stripped: nlmsghdr, plus
        # genlmsghdr for generic netlink. That is 16 or 20 bytes, not always 20.
        offset = nl_msg.nl_len - len(msg.raw) + fixed_header_size
        path = self._decode_extack_path(raw_attrs, op.attr_set, offset,
                                        extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def _fixed_header_size(self, op):
        """Return the byte size of ``op``'s fixed header (0 if it has none)."""
        if not op.fixed_header:
            return 0
        return sum(NlAttr.get_format(m.type, m.byte_order).size
                   for m in self.consts[op.fixed_header].members)

    def _decode_fixed_header(self, msg, name):
        """Unpack fixed-header struct ``name`` from the start of ``msg.raw``."""
        fixed_header_members = self.consts[name].members
        fixed_header_attrs = dict()
        offset = 0
        for m in fixed_header_members:
            fmt = NlAttr.get_format(m.type, m.byte_order)
            [ value ] = fmt.unpack_from(msg.raw, offset)
            offset += fmt.size
            if m.enum:
                value = self._decode_enum(value, m)
            fixed_header_attrs[m.name] = value
        return fixed_header_attrs

    def handle_ntf(self, decoded):
        """Decode a notification and append {'name', 'msg'} to async_msg_queue."""
        msg = dict()
        if self.include_raw:
            msg['raw'] = decoded
        op = self.rsp_by_value[decoded.cmd()]
        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
        if op.fixed_header:
            attrs.update(self._decode_fixed_header(decoded, op.fixed_header))

        msg['name'] = op['name']
        msg['msg'] = attrs
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Drain pending notifications without blocking, queueing each one."""
        while True:
            try:
                reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                decoded = self.nlproto.decode(self, nl_msg)
                if decoded.cmd() not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", decoded)
                    continue

                self.handle_ntf(decoded)

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _op(self, method, vals, flags, dump=False):
        """Send operation ``method`` with attributes ``vals`` and collect replies.

        ``flags`` are extra NLM_F_* flags OR-ed into the request.
        ``vals`` is not modified.
        Returns None when there is no reply.
        Returns the single decoded dict for a non-dump that got exactly one reply.
        Otherwise returns a list of dicts.
        Raises NlError if the kernel reports an error.
        """
        op = self.ops[method]
        # Work on a copy: fixed-header fields are popped out below and the
        # caller's dict must stay intact.
        vals = dict(vals) if vals else dict()

        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag
        if dump:
            nl_flags |= Netlink.NLM_F_DUMP

        req_seq = random.randint(1024, 65535)
        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            for m in self.consts[op.fixed_header].members:
                value = vals.pop(m.name, 0)
                fmt = NlAttr.get_format(m.type, m.byte_order)
                msg += fmt.pack(value)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value)
        msg = _genl_msg_finalize(msg)

        self.sock.send(msg, 0)

        done = False
        rsp = []
        while not done:
            reply = self.sock.recv(128 * 1024)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            for nl_msg in nms:
                if nl_msg.extack:
                    self._decode_extack(msg, op, nl_msg.extack)

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)
                    done = True
                    break

                decoded = self.nlproto.decode(self, nl_msg)

                # Check if this is a reply to our request
                if nl_msg.nl_seq != req_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_fixed_header(decoded, op.fixed_header))
                rsp.append(rsp_msg)

        if not rsp:
            return None
        if not dump and len(rsp) == 1:
            return rsp[0]
        return rsp

    def do(self, method, vals, flags=None):
        """Run ``method`` as a single (non-dump) request."""
        return self._op(method, vals, flags)

    def dump(self, method, vals):
        """Run ``method`` as a dump request (NLM_F_DUMP); always returns a list or None."""
        return self._op(method, vals, [], dump=True)
777