xref: /linux/tools/net/ynl/lib/ynl.py (revision c4101e55974cc7d835fbd2d8e01553a3f61e9e75)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4import functools
5import os
6import random
7import socket
8import struct
9from struct import Struct
10import yaml
11import ipaddress
12import uuid
13
14from .nlspec import SpecFamily
15
16#
17# Generic Netlink code which should really be in some library, but I can't quickly find one.
18#
19
20
class Netlink:
    """Numeric netlink / generic-netlink constants used throughout this module.

    Flag values are written as bit shifts so the bit position is explicit;
    several NLM_F_* names intentionally share bits because their meaning
    depends on the message kind (GET request, NEW request, or ACK).
    """

    # Socket level and socket options
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Control message types
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # Standard message flags
    NLM_F_REQUEST = 1 << 0
    NLM_F_ACK = 1 << 2

    # Modifiers for GET requests
    NLM_F_ROOT = 1 << 8
    NLM_F_MATCH = 1 << 9

    # Modifiers for NEW requests
    NLM_F_REPLACE = 1 << 8
    NLM_F_EXCL = 1 << 9
    NLM_F_CREATE = 1 << 10
    NLM_F_APPEND = 1 << 11

    # Flags found on ACK / error messages
    NLM_F_CAPPED = 1 << 8
    NLM_F_ACK_TLVS = 1 << 9

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # Flag bits carried in the nla_type field of an attribute header
    NLA_F_NESTED = 1 << 15
    NLA_F_NET_BYTEORDER = 1 << 14

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Generic netlink protocol number and controller family id
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl command and attributes
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extended ACK TLV types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6
77
78
class NlError(Exception):
    """Raised when the kernel answers a request with a netlink error.

    Attributes:
        nl_msg: the parsed error message; its ``error`` field holds the
            negative errno, which __str__ turns into a readable message.
    """

    def __init__(self, nl_msg):
        # Hand the message to Exception so ``args`` is populated rather
        # than left empty.
        super().__init__(nl_msg)
        self.nl_msg = nl_msg

    def __str__(self):
        return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}"
85
86
class NlAttr:
    """One netlink attribute (nlattr header + payload) parsed from a buffer.

    Attributes:
        type: attribute type with the NLA_F_NESTED / NLA_F_NET_BYTEORDER
            flag bits masked off.
        is_nest: non-zero when the NLA_F_NESTED flag bit was set.
        payload_len: the raw nla_len field (it includes the 4-byte header).
        full_len: nla_len rounded up to a multiple of 4, i.e. the distance
            from this attribute to the next one.
        raw: the payload bytes (header removed, trailing padding excluded).
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # nla_len counts the 4-byte header, so anything smaller is malformed.
        # A zero length would also make full_len 0, and NlAttrs (which steps
        # by full_len) would then loop forever.
        if self._len < 4:
            raise Exception(f"Invalid netlink attribute length {self._len} at offset {offset}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the struct.Struct used for scalar type ``attr_type``.

        byte_order "big-endian" selects big-endian, any other truthy value
        little-endian, and None/empty the host's native order. Raises
        KeyError for type names not in type_formats.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" \
                else fmt.little
        return fmt.native

    @classmethod
    def formatted_string(cls, raw, display_hint):
        """Render ``raw`` according to a spec display-hint.

        'mac' -> colon-separated hex, 'hex' -> space-separated hex,
        'ipv4'/'ipv6' -> address text, 'uuid' -> UUID text; any other hint
        returns ``raw`` unchanged.
        """
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the payload as the fixed-size scalar ``attr_type``."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a variable-width scalar (uint/sint) from a 4- or 8-byte payload.

        The first letter of ``attr_type`` ('u' or 's') picks signedness and
        the payload length picks the width.
        """
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar len payload must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode an ASCII string, dropping the final (NUL terminator) byte."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes unchanged."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as a packed array of native-order scalars."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def as_struct(self, members):
        """Decode the payload as a C struct described by ``members``.

        Each member provides name, type, len, byte_order and display_hint.
        'pad' members only advance the cursor and produce no entry.
        'binary' members take ``len`` bytes; scalar types are unpacked.
        Any other member type raises, because decoding it is not supported
        and continuing would misalign every following member.
        """
        value = dict()
        offset = 0
        for m in members:
            # TODO: handle non-scalar members
            if m.type == 'pad':
                offset += m.len
                continue
            if m.type == 'binary':
                decoded = self.raw[offset : offset + m.len]
                offset += m.len
            elif m.type in NlAttr.type_formats:
                fmt = self.get_format(m.type, m.byte_order)
                [ decoded ] = fmt.unpack_from(self.raw, offset)
                offset += fmt.size
            else:
                raise Exception(f"Unsupported struct member type '{m.type}' for '{m.name}'")
            if m.display_hint:
                decoded = self.formatted_string(decoded, m.display_hint)
            value[m.name] = decoded
        return value

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
170
171
class NlAttrs:
    """The attributes laid out back-to-back in a buffer, as NlAttr objects.

    Parsing starts at ``offset`` and runs to the end of ``msg``, stepping by
    each attribute's padded length (full_len).
    """

    def __init__(self, msg, offset=0):
        parsed = []
        cursor = offset
        end = len(msg)
        while cursor < end:
            item = NlAttr(msg, cursor)
            parsed.append(item)
            cursor += item.full_len
        self.attrs = parsed

    def __iter__(self):
        return iter(self.attrs)

    def __repr__(self):
        # One attribute per line.
        return '\n'.join(repr(item) for item in self.attrs)
191
192
class NlMsg:
    """One netlink message (16-byte nlmsghdr + payload) parsed from a buffer.

    Attributes:
        nl_len, nl_type, nl_flags, nl_seq, nl_portid: nlmsghdr fields.
        raw: the payload after the 16-byte header, up to nl_len.
        error: the s32 error code of an NLMSG_ERROR / NLMSG_DONE message,
            otherwise 0.
        done: 1 for NLMSG_ERROR / NLMSG_DONE messages, otherwise 0.
        extack: dict of decoded extended-ack TLVs, or None if there were none.
    """

    def __init__(self, msg, offset, attr_space=None):
        """Parse the message starting at ``offset`` in ``msg``.

        attr_space, if given, is used to replace a numeric top-level
        'miss-type' extack value with the attribute's name (and doc).
        """
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # s32 error + echoed 16-byte request header; assumes the echo
            # is capped to the header (NETLINK_CAP_ACK) — TODO confirm for
            # sockets that do not set it.
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            # The end-of-dump message carries an s32 error code too; a dump
            # that fails part-way reports its errno here, so keep it rather
            # than treating the dump as successfully finished.
            if len(self.raw) >= 4:
                self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        desc = spec['name']
                        if 'doc' in spec:
                            desc += f" ({spec['doc']})"
                        self.extack['miss-type'] = desc

    def cmd(self):
        """Return the message type (nlmsg_type)."""
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
        if self.error:
            # Terminate the line so a following extack starts on its own line.
            msg += '\terror: ' + str(self.error) + '\n'
        if self.extack:
            msg += '\textack: ' + repr(self.extack)
        return msg
253
254
class NlMsgs:
    """All netlink messages contained in one receive buffer, in order."""

    def __init__(self, data, attr_space=None):
        """Split ``data`` into NlMsg objects.

        attr_space is forwarded to each NlMsg for extack decoding. Raises if
        a message header reports a length shorter than the header itself.
        """
        self.msgs = []

        offset = 0
        while offset < len(data):
            # nlmsg_len includes the 16-byte header; a smaller value cannot
            # be a valid message, and 0 would make this loop spin forever.
            nl_len = struct.unpack_from("I", data, offset)[0]
            if nl_len < 16:
                raise Exception(f"Invalid netlink message length {nl_len} at offset {offset}")
            msg = NlMsg(data, offset, attr_space=attr_space)
            offset += msg.nl_len
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
267
268
# Lazily-filled cache of generic netlink families (see _genl_load_families):
# family name -> dict with 'id', 'name' and, when reported, 'maxattr' and
# 'mcast' (multicast group name -> group id). None until first loaded.
genl_family_name_to_id = None
270
271
272def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
273    # we prepend length in _genl_msg_finalize()
274    if seq is None:
275        seq = random.randint(1, 1024)
276    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
277    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
278    return nlmsg + genlmsg
279
280
281def _genl_msg_finalize(msg):
282    return struct.pack("I", len(msg) + 4) + msg
283
284
def _genl_parse_mcast_groups(groups_attr):
    """Return {group name: group id} from a CTRL_ATTR_MCAST_GROUPS nest.

    Entries lacking a (non-empty) name or an id are skipped.
    """
    groups = dict()
    for entry in NlAttrs(groups_attr.raw):
        grp_name = None
        grp_id = None
        for grp_attr in NlAttrs(entry.raw):
            if grp_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                grp_name = grp_attr.as_strz()
            elif grp_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                grp_id = grp_attr.as_scalar('u32')
        if grp_name and grp_id is not None:
            groups[grp_name] = grp_id
    return groups


def _genl_load_families():
    """Dump every generic netlink family from nlctrl into genl_family_name_to_id.

    Each family that reports both a name and an id is stored under its name
    as a dict with 'id', 'name' and, when present, 'maxattr' and 'mcast'.
    A netlink error is printed and stops the load early, leaving whatever
    was collected so far in the cache.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        dump_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP
        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL, dump_flags, Netlink.CTRL_CMD_GETFAMILY, 1))
        sock.send(request, 0)

        genl_family_name_to_id = dict()

        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                family = dict()
                for attr in NlAttrs(GenlMsg(nl_msg).raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        family['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        family['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        family['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        family['mcast'] = _genl_parse_mcast_groups(attr)
                if 'name' in family and 'id' in family:
                    genl_family_name_to_id[family['name']] = family
332
333
class GenlMsg:
    """Generic netlink view of an NlMsg: splits off the 4-byte genlmsghdr.

    Attributes:
        nl: the underlying NlMsg.
        genl_cmd, genl_version: fields of the genlmsghdr.
        raw: the payload following the genl header.
    ``raw_attrs`` is not set here; NetlinkProtocol.decode() attaches it.
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the generic netlink command."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs exists only after decoding; don't raise AttributeError
        # when printing a message that was never decoded.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg
348        return msg
349
350
class NetlinkProtocol:
    """Message framing for a classic (netlink-raw) family.

    The command doubles as nlmsg_type and there is no extra header after
    the nlmsghdr. Subclasses override message(), _decode() and
    get_mcast_id() to add protocol-specific behaviour.
    """

    def __init__(self, family_name, proto_num):
        self.family_name, self.proto_num = family_name, proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack an nlmsghdr without nlmsg_len; seq defaults to random 1..1024."""
        if seq is not None:
            sequence = seq
        else:
            sequence = random.randint(1, 1024)
        return struct.pack("HHII", nl_type, nl_flags, sequence, 0)

    def message(self, flags, command, version, seq=None):
        """Build the request header; ``version`` is unused for raw netlink."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        """Raw netlink messages need no extra unwrapping."""
        return nl_msg

    def decode(self, ynl, nl_msg):
        """Unwrap ``nl_msg`` and attach its attributes as ``raw_attrs``.

        When a YnlFamily is given, the reply op's fixed header is skipped
        before parsing attributes.
        """
        unwrapped = self._decode(nl_msg)
        attrs_start = 0
        if ynl:
            reply_op = ynl.rsp_by_value[unwrapped.cmd()]
            attrs_start = ynl._fixed_header_size(reply_op.fixed_header)
        unwrapped.raw_attrs = NlAttrs(unwrapped.raw, attrs_start)
        return unwrapped

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Look up a multicast group id from the spec's group table."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
381
382
class GenlProtocol(NetlinkProtocol):
    """Generic netlink framing.

    The family id is resolved through the nlctrl family cache, and a 4-byte
    genlmsghdr (cmd, version, reserved) follows the nlmsghdr.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        # KeyError here means the kernel did not report the family; the
        # caller (YnlFamily) turns it into a "not supported" error.
        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = self.genl_family['id']

    def message(self, flags, command, version, seq=None):
        """Build nlmsghdr (without length) + genlmsghdr for ``command``."""
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        """Wrap the raw message so the genl header is split off."""
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the kernel-assigned id of multicast group ``mcast_name``.

        The spec's ``mcast_groups`` are ignored; genl group ids come from
        the kernel's family dump.
        """
        # A family without multicast groups has no 'mcast' key at all, so a
        # plain lookup would surface as a bare KeyError.
        family_groups = self.genl_family.get('mcast', {})
        if mcast_name not in family_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return family_groups[mcast_name]
406
407
408#
409# YNL implementation details.
410#
411
412
class YnlFamily(SpecFamily):
    """A YNL-spec netlink family bound to an open netlink socket.

    Every operation in the spec is exposed as a method named after the op's
    ident_name (a functools.partial of _op). Notifications seen while
    waiting for replies, or picked up by check_ntf(), are appended to
    ``async_msg_queue``.
    """

    def __init__(self, def_path, schema=None, process_unknown=False):
        """Load the spec at ``def_path`` and open a socket for its protocol.

        When ``process_unknown`` is set, attributes not described by the
        spec are returned as "UnknownAttr(N)" entries instead of raising.
        Raises Exception when setting up the protocol fails with KeyError
        (e.g. the generic netlink family is not registered in the kernel).
        """
        super().__init__(def_path, schema)

        self.include_raw = False
        self.process_unknown = process_unknown

        try:
            if self.proto == "netlink-raw":
                self.nlproto = NetlinkProtocol(self.yaml['name'],
                                               self.yaml['protonum'])
            else:
                self.nlproto = GenlProtocol(self.yaml['name'])
        except KeyError:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

        # Response ids of messages the spec marks as async (notifications).
        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        # Expose each op as a bound method named after its ident_name.
        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)


    def ntf_subscribe(self, mcast_name):
        """Join multicast group ``mcast_name`` to start receiving its notifications."""
        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_id)

    def _add_attr(self, space, name, value):
        """Encode attribute ``name`` of attribute-set ``space`` as padded TLV bytes."""
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value
        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
        elif attr["type"] == 'flag':
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            if isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                attr_payload = bytes.fromhex(value)
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr.is_auto_scalar:
            scalar = int(value)
            # Use the 32-bit encoding only if the value fits it. For signed
            # types bit_length() <= 32 would also admit 2**31..2**32-1,
            # which overflows s32.
            if attr["type"][0] == 's':
                fits_32 = -2**31 <= scalar < 2**31
            else:
                fits_32 = scalar.bit_length() <= 32
            real_type = attr["type"][0] + ('32' if fits_32 else '64')
            format = NlAttr.get_format(real_type, attr.byte_order)
            attr_payload = format.pack(scalar)
        elif attr['type'] in NlAttr.type_formats:
            format = NlAttr.get_format(attr['type'], attr.byte_order)
            attr_payload = format.pack(int(value))
        elif attr['type'] == "bitfield32":
            attr_payload = struct.pack("II", int(value["value"]), int(value["selector"]))
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Map integer ``raw`` to its enum entry name.

        For flags enums (or enum-as-flags attributes) return the set of
        names of all set bits, looked up by bit index.
        """
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        """Decode a 'binary' attribute.

        It is decoded as a struct when the spec names one (enum members
        mapped to names), as a C array when it has a sub-type, and otherwise
        returned as bytes, formatted if a display-hint is set.
        """
        if attr_spec.struct_name:
            members = self.consts[attr_spec.struct_name]
            decoded = attr.as_struct(members)
            for m in members:
                if m.enum:
                    decoded[m.name] = self._decode_enum(decoded[m.name], m)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = NlAttr.formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode_array_nest(self, attr, attr_spec):
        """Decode an array-nest into a list of {item type: decoded nested attrs}."""
        decoded = []
        offset = 0
        while offset < len(attr.raw):
            item = NlAttr(attr.raw, offset)
            offset += item.full_len

            subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
            decoded.append({ item.type: subattrs })
        return decoded

    def _decode_unknown(self, attr):
        """Decode an attribute missing from the spec.

        A nest is decoded recursively with no attribute-set; anything else
        is returned as raw bytes.
        """
        if attr.is_nest:
            return self._decode(NlAttrs(attr.raw), None)
        else:
            return attr.as_bin()

    def _rsp_add(self, rsp, name, is_multi, decoded):
        """Store ``decoded`` under ``name`` in ``rsp``.

        is_multi True collects values in a list and False overwrites. None
        (attribute unknown to the spec) stores the first value as-is and
        switches to a list once the same name appears again.
        """
        if is_multi is None:
            if name in rsp:
                if not isinstance(rsp[name], list):
                    rsp[name] = [rsp[name]]
                # Keep appending on the 3rd and later occurrences instead of
                # replacing the accumulated list with the newest value.
                is_multi = True
            else:
                is_multi = False

        if not is_multi:
            rsp[name] = decoded
        elif name in rsp:
            rsp[name].append(decoded)
        else:
            rsp[name] = [decoded]

    def _resolve_selector(self, attr_spec, vals):
        """Pick the sub-message format for ``attr_spec`` via its selector value in ``vals``."""
        sub_msg = attr_spec.sub_message
        if sub_msg not in self.sub_msgs:
            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
        sub_msg_spec = self.sub_msgs[sub_msg]

        selector = attr_spec.selector
        if selector not in vals:
            raise Exception(f"There is no value for {selector} to resolve '{attr_spec.name}'")
        value = vals[selector]
        if value not in sub_msg_spec.formats:
            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")

        spec = sub_msg_spec.formats[value]
        return spec

    def _decode_sub_msg(self, attr, attr_spec, rsp):
        """Decode a sub-message attribute.

        The format is chosen by the selector already decoded into ``rsp``.
        """
        msg_format = self._resolve_selector(attr_spec, rsp)
        decoded = {}
        offset = 0
        if msg_format.fixed_header:
            decoded.update(self._decode_fixed_header(attr, msg_format.fixed_header))
            offset = self._fixed_header_size(msg_format.fixed_header)
        if msg_format.attr_set:
            if msg_format.attr_set in self.attr_sets:
                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
                decoded.update(subdict)
            else:
                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
        return decoded

    def _decode(self, attrs, space):
        """Decode ``attrs`` using attribute-set ``space`` into a dict.

        A falsy ``space`` (used for nests of unknown attributes) makes every
        attribute unknown. Unknown attributes raise unless process_unknown
        is set, in which case they are stored as "UnknownAttr(<type>)".
        """
        attr_space = self.attr_sets[space] if space else None
        rsp = dict()
        for attr in attrs:
            if attr_space is None or attr.type not in attr_space.attrs_by_val:
                if not self.process_unknown:
                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
                attr_name = f"UnknownAttr({attr.type})"
                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
                continue
            attr_spec = attr_space.attrs_by_val[attr.type]

            if attr_spec["type"] == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
                decoded = subdict
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                decoded = True
            elif attr_spec.is_auto_scalar:
                decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
            elif attr_spec["type"] in NlAttr.type_formats:
                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
                if 'enum' in attr_spec:
                    decoded = self._decode_enum(decoded, attr_spec)
            elif attr_spec["type"] == 'array-nest':
                decoded = self._decode_array_nest(attr, attr_spec)
            elif attr_spec["type"] == 'bitfield32':
                value, selector = struct.unpack("II", attr.raw)
                if 'enum' in attr_spec:
                    value = self._decode_enum(value, attr_spec)
                    selector = self._decode_enum(selector, attr_spec)
                decoded = {"value": value, "selector": selector}
            elif attr_spec["type"] == 'sub-message':
                decoded = self._decode_sub_msg(attr, attr_spec, rsp)
            else:
                if not self.process_unknown:
                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
                decoded = self._decode_unknown(attr)

            self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return the '.a.b' path of the attribute at request offset ``target``.

        ``offset`` is the absolute offset of the first attribute in
        ``attrs``. Returns None if no attribute starts at ``target``.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Step over this attribute's 4-byte header into its payload.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, op, extack):
        """Replace 'bad-attr-offs' in ``extack`` with an attribute path.

        The path is resolved against the bytes of ``request``; the offset
        entry is only replaced when a path is found.
        """
        if 'bad-attr-offs' not in extack:
            return

        # Unwrap the request with the protocol's _decode() rather than
        # decode(): decode() looks the op up by *response* value, and an
        # op's req_value and rsp_value may differ.
        msg = self.nlproto._decode(NlMsg(request, 0, op.attr_set))
        fixed_header_size = self._fixed_header_size(op.fixed_header)
        # Absolute offset of the first attribute. The protocol header size
        # (16 for netlink-raw, 20 for genetlink) is whatever precedes
        # msg.raw; the op's fixed header comes after it.
        offset = len(request) - len(msg.raw) + fixed_header_size
        path = self._decode_extack_path(NlAttrs(msg.raw, fixed_header_size),
                                        op.attr_set, offset,
                                        extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def _fixed_header_size(self, name):
        """Return the byte size of fixed-header struct ``name`` (0 if none)."""
        if name:
            fixed_header_members = self.consts[name].members
            size = 0
            for m in fixed_header_members:
                if m.type in ['pad', 'binary']:
                    size += m.len
                else:
                    format = NlAttr.get_format(m.type, m.byte_order)
                    size += format.size
            return size
        else:
            return 0

    def _decode_fixed_header(self, msg, name):
        """Decode fixed-header struct ``name`` from the start of ``msg.raw``.

        Pad members are skipped and enum members are mapped to names.
        """
        fixed_header_members = self.consts[name].members
        fixed_header_attrs = dict()
        offset = 0
        for m in fixed_header_members:
            value = None
            if m.type == 'pad':
                offset += m.len
            elif m.type == 'binary':
                value = msg.raw[offset : offset + m.len]
                offset += m.len
            else:
                format = NlAttr.get_format(m.type, m.byte_order)
                [ value ] = format.unpack_from(msg.raw, offset)
                offset += format.size
            if value is not None:
                if m.enum:
                    value = self._decode_enum(value, m)
                fixed_header_attrs[m.name] = value
        return fixed_header_attrs

    def handle_ntf(self, decoded):
        """Decode notification ``decoded`` and append it to async_msg_queue.

        The queued dict has 'name' and 'msg', plus 'raw' if include_raw is set.
        """
        msg = dict()
        if self.include_raw:
            msg['raw'] = decoded
        op = self.rsp_by_value[decoded.cmd()]
        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
        if op.fixed_header:
            attrs.update(self._decode_fixed_header(decoded, op.fixed_header))

        msg['name'] = op['name']
        msg['msg'] = attrs
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Read all pending messages without blocking and queue notifications.

        Errors, done markers and non-notification messages are printed and
        skipped.
        """
        while True:
            try:
                reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                decoded = self.nlproto.decode(self, nl_msg)
                if decoded.cmd() not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", decoded)
                    continue

                self.handle_ntf(decoded)

    def operation_do_attributes(self, name):
        """Return a copy of the 'do' request attribute list of operation ``name``.

        Returns None if find_operation() finds no such operation.
        NOTE(review): find_operation is not defined in this file; presumably
        it is provided by SpecFamily — confirm.
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _op(self, method, vals, flags=None, dump=False):
        """Send request ``method`` with attribute values ``vals`` and collect replies.

        Args:
            method: op name (key of self.ops).
            vals: dict of attribute name -> value. Entries naming members of
                the op's fixed header fill that header instead. The caller's
                dict is not modified; None means no attributes.
            flags: extra NLM_F_* flags OR-ed into the request flags.
            dump: send a dump request (NLM_F_DUMP).
        Returns:
            None if no reply message was collected, the single reply dict
            for a non-dump op with one reply, otherwise a list of dicts.
        Raises:
            NlError if the kernel answers with an error.
        """
        op = self.ops[method]
        # Work on a copy: fixed-header members are popped below, and that
        # must not leak into the caller's dict.
        vals = dict(vals) if vals else {}

        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag
        if dump:
            nl_flags |= Netlink.NLM_F_DUMP

        req_seq = random.randint(1024, 65535)
        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            for m in self.consts[op.fixed_header].members:
                present = m.name in vals
                value = vals.pop(m.name) if present else 0
                if m.type == 'pad':
                    msg += bytearray(m.len)
                elif m.type == 'binary':
                    # Omitted binary members are zero-filled like the scalar
                    # ones (bytes.fromhex(0) would raise TypeError).
                    msg += bytes.fromhex(value) if present else bytearray(m.len)
                else:
                    format = NlAttr.get_format(m.type, m.byte_order)
                    msg += format.pack(value)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value)
        msg = _genl_msg_finalize(msg)

        self.sock.send(msg, 0)

        done = False
        rsp = []
        while not done:
            reply = self.sock.recv(128 * 1024)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            for nl_msg in nms:
                if nl_msg.extack:
                    self._decode_extack(msg, op, nl_msg.extack)

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)
                    done = True
                    break

                decoded = self.nlproto.decode(self, nl_msg)

                # Check if this is a reply to our request
                if nl_msg.nl_seq != req_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_fixed_header(decoded, op.fixed_header))
                rsp.append(rsp_msg)

        if not rsp:
            return None
        if not dump and len(rsp) == 1:
            return rsp[0]
        return rsp

    def do(self, method, vals, flags=None):
        """Run a 'do' request for op ``method``; see _op()."""
        return self._op(method, vals, flags)

    def dump(self, method, vals):
        """Run a dump request for op ``method``; see _op()."""
        return self._op(method, vals, [], dump=True)
825