xref: /linux/tools/net/ynl/lib/ynl.py (revision db7193a5c9db4babbfc99798f2e636b12419c12b)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4import functools
5import os
6import random
7import socket
8import struct
9from struct import Struct
10import yaml
11import ipaddress
12import uuid
13
14from .nlspec import SpecFamily
15
16#
17# Generic Netlink code which should really be in some library, but I can't quickly find one.
18#
19
20
class Netlink:
    """Numeric netlink / generic netlink constants used by this module.

    Names match the kernel's UAPI definitions so they can be grepped for;
    single-bit flag values are spelled as bit shifts.
    """
    # Socket level and socket options for setsockopt(SOL_NETLINK, ...)
    SOL_NETLINK = 270
    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11

    # Control message types (nlmsg_type)
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # Request flags (nlmsg_flags)
    NLM_F_REQUEST = 1 << 0
    NLM_F_ACK = 1 << 2
    NLM_F_ROOT = 1 << 8
    NLM_F_MATCH = 1 << 9
    NLM_F_APPEND = 1 << 11
    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # Flags checked on ACK / error replies; they reuse the ROOT and MATCH bits
    NLM_F_CAPPED = 1 << 8
    NLM_F_ACK_TLVS = 1 << 9

    # High bits of nla_type; masked off to obtain the attribute type itself
    NLA_F_NESTED = 1 << 15
    NLA_F_NET_BYTEORDER = 1 << 14
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Generic netlink protocol number and the id of its controller family
    NETLINK_GENERIC = 16
    GENL_ID_CTRL = 0x10

    # nlctrl command and the top-level attributes read from its replies
    CTRL_CMD_GETFAMILY = 3
    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    # Attributes inside each CTRL_ATTR_MCAST_GROUPS entry
    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extended ACK TLV types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6
72
73
class NlError(Exception):
    """Raised when the kernel answers a request with a non-zero error code.

    ``nl_msg`` is the NlMsg carrying the error reply; its ``error`` field
    holds the negated errno, which is turned into a readable message.
    """
    def __init__(self, nl_msg):
        # Pass nl_msg to Exception so .args is populated; without it the
        # exception cannot be re-created by pickle/copy (NlError() fails).
        super().__init__(nl_msg)
        self.nl_msg = nl_msg

    def __str__(self):
        return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}"
80
81
class NlAttr:
    """A single netlink attribute (4-byte header + payload) inside a buffer.

    Instance fields set by the constructor:
      type        -- attribute type with NLA_F_NESTED / NLA_F_NET_BYTEORDER masked off
      payload_len -- nla_len as found on the wire (includes the 4-byte header)
      full_len    -- nla_len rounded up to 4 bytes, i.e. the distance to the next attribute
      raw         -- the payload bytes, header and padding stripped
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute starting at byte *offset* of *raw*."""
        self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        # NOTE: despite the name this is nla_len, which counts the header too.
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4:offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for scalar *attr_type*.

        *byte_order* "big-endian" selects big-endian, any other truthy value
        little-endian, and a falsy value native byte order.
        """
        format = cls.type_formats[attr_type]
        if byte_order:
            return format.big if byte_order == "big-endian" \
                else format.little
        return format.native

    @classmethod
    def formatted_string(cls, raw, display_hint):
        """Render *raw* per *display_hint* ('mac', 'hex', 'ipv4', 'ipv6',
        'uuid'); any other hint returns *raw* unchanged."""
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as one scalar; its size must match exactly."""
        format = self.get_format(attr_type, byte_order)
        return format.unpack(self.raw)[0]

    def as_strz(self):
        """Decode a NUL-terminated ASCII payload, dropping the final byte."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes unchanged."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of native-order scalars."""
        format = self.get_format(type)
        return [ x[0] for x in format.iter_unpack(self.raw) ]

    def as_struct(self, members):
        """Decode the payload as consecutive struct *members*.

        Returns a dict keyed by member name. Binary members consume
        ``m['len']`` bytes, scalar members their format size.

        Raises:
            Exception: for a member type that cannot be decoded yet.
        """
        value = dict()
        offset = 0
        for m in members:
            # TODO: handle non-scalar members
            if m.type == 'binary':
                decoded = self.raw[offset:offset+m['len']]
                offset += m['len']
            elif m.type in NlAttr.type_formats:
                format = self.get_format(m.type, m.byte_order)
                [ decoded ] = format.unpack_from(self.raw, offset)
                offset += format.size
            else:
                # Falling through would reuse the previous member's value (or
                # hit an unbound name) and leave the offset stale.
                raise Exception(f"Unsupported struct member type '{m.type}' for member '{m.name}'")
            if m.display_hint:
                decoded = self.formatted_string(decoded, m.display_hint)
            value[m.name] = decoded
        return value

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
157
158
class NlAttrs:
    """The NlAttr objects laid out back-to-back in a byte buffer."""
    def __init__(self, msg):
        """Parse every attribute in *msg*.

        Raises:
            Exception: if an attribute declares nla_len < 4.
        """
        self.attrs = []

        offset = 0
        while offset < len(msg):
            attr = NlAttr(msg, offset)
            # nla_len includes the 4-byte header, so anything shorter is
            # malformed; a zero length would never advance offset and loop.
            if attr.payload_len < 4:
                raise Exception(f"Malformed netlink attribute: nla_len {attr.payload_len} at offset {offset}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        # One attribute per line.
        return '\n'.join(repr(a) for a in self.attrs)
178        return msg
179
180
class NlMsg:
    """One netlink message parsed from *msg* at byte *offset*.

    Exposes the nlmsghdr fields (nl_len, nl_type, nl_flags, nl_seq,
    nl_portid) and the payload in ``raw``. For NLMSG_ERROR / NLMSG_DONE it
    also sets ``error`` (the signed error code, error messages only),
    ``done`` and, when NLM_F_ACK_TLVS is set, ``extack``: a dict of the
    decoded extended-ACK TLVs. When *attr_space* is given, a top-level
    'miss-type' value is replaced by the missing attribute's name (and doc).
    """
    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset:offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16:offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # Error code (4 bytes) followed by the request's nlmsghdr (16).
            extack_off = 20
            if not (self.nl_flags & Netlink.NLM_F_CAPPED) and len(self.raw) >= 8:
                # Uncapped errors echo the whole request, so the TLVs start
                # after the 4-byte-aligned copy; its length is the echoed
                # nlmsg_len.
                orig_len = struct.unpack("I", self.raw[4:8])[0]
                extack_off = 4 + ((orig_len + 3) & ~3)
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.done = 1
            # A 4-byte int precedes any TLVs.
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        desc = spec['name']
                        if 'doc' in spec:
                            desc += f" ({spec['doc']})"
                        self.extack['miss-type'] = desc

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
        if self.error:
            msg += '\terror: ' + str(self.error)
        if self.extack:
            msg += '\textack: ' + repr(self.extack)
        return msg
237        return msg
238
239
class NlMsgs:
    """All netlink messages contained in one received buffer, in order."""
    def __init__(self, data, attr_space=None):
        """Split *data* into NlMsg objects; *attr_space* is passed to each.

        Raises:
            Exception: if a message declares nlmsg_len < 16.
        """
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            # nlmsg_len includes the 16-byte header; a shorter value is
            # malformed and a zero one would never advance offset.
            if msg.nl_len < 16:
                raise Exception(f"Malformed netlink message: nlmsg_len {msg.nl_len} at offset {offset}")
            # Like NLMSG_NEXT(): the next header starts at NLMSG_ALIGN(nlmsg_len).
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
251        yield from self.msgs
252
253
# Cache of the kernel's generic netlink families, filled on first use by
# _genl_load_families(): family name -> dict with 'id' and 'name' plus,
# when the kernel reported them, 'maxattr' and 'mcast' (group name -> id).
# None means "not loaded yet".
genl_family_name_to_id = None
255
256
257def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
258    # we prepend length in _genl_msg_finalize()
259    if seq is None:
260        seq = random.randint(1, 1024)
261    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
262    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
263    return nlmsg + genlmsg
264
265
266def _genl_msg_finalize(msg):
267    return struct.pack("I", len(msg) + 4) + msg
268
269
def _genl_parse_mcast_groups(groups_attr):
    """Map group name -> group id from a CTRL_ATTR_MCAST_GROUPS nest.

    Entries without a non-empty name or without an id are skipped.
    """
    groups = dict()
    for entry in NlAttrs(groups_attr.raw):
        grp_name = None
        grp_id = None
        for field in NlAttrs(entry.raw):
            if field.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                grp_name = field.as_strz()
            elif field.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                grp_id = field.as_scalar('u32')
        if grp_name and grp_id is not None:
            groups[grp_name] = grp_id
    return groups


def _genl_parse_family(genl_msg):
    """Collect id, name, maxattr and multicast groups from one nlctrl reply."""
    fam = dict()
    for attr in genl_msg.raw_attrs:
        if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
            fam['id'] = attr.as_scalar('u16')
        elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
            fam['name'] = attr.as_strz()
        elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
            fam['maxattr'] = attr.as_scalar('u32')
        elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
            fam['mcast'] = _genl_parse_mcast_groups(attr)
    return fam


def _genl_load_families():
    """Fill the global family cache from a CTRL_CMD_GETFAMILY dump.

    Uses a short-lived NETLINK_GENERIC socket, resets
    ``genl_family_name_to_id`` to an empty dict after sending the request,
    then records every reported family that has both a name and an id.
    On a netlink error the code is printed and loading stops, leaving the
    families collected so far in the cache.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        dump_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP
        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL, dump_flags, Netlink.CTRL_CMD_GETFAMILY, 1))
        sock.send(request, 0)

        genl_family_name_to_id = dict()

        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                family = _genl_parse_family(GenlMsg(nl_msg))
                if 'name' in family and 'id' in family:
                    genl_family_name_to_id[family['name']] = family
317
318
class GenlMsg:
    """Generic netlink view of an NlMsg.

    Splits the payload into the 4-byte genlmsghdr (``genl_cmd``,
    ``genl_version``), an optional family fixed header decoded into
    ``fixed_header_attrs`` (member name -> value), and the remaining
    attributes (``raw`` bytes and parsed ``raw_attrs``).
    """
    def __init__(self, nl_msg, fixed_header_members=None):
        """*fixed_header_members*: optional iterable of scalar struct members
        that follow the genlmsghdr; omitted or None means no fixed header."""
        self.nl = nl_msg

        self.hdr = nl_msg.raw[0:4]
        offset = 4

        self.genl_cmd, self.genl_version, _ = struct.unpack("BBH", self.hdr)

        self.fixed_header_attrs = dict()
        # None default instead of a shared mutable [] default.
        for m in fixed_header_members or ():
            format = NlAttr.get_format(m.type, m.byte_order)
            decoded = format.unpack_from(nl_msg.raw, offset)
            offset += format.size
            self.fixed_header_attrs[m.name] = decoded[0]

        self.raw = nl_msg.raw[offset:]
        self.raw_attrs = NlAttrs(self.raw)

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        for a in self.raw_attrs:
            msg += '\t\t' + repr(a) + '\n'
        return msg
343        return msg
344
345
class GenlFamily:
    """Kernel-side identity of a generic netlink family, looked up by name.

    Loads the module-level family cache on first use. ``genl_family`` is
    the cached family dict and ``family_id`` its numeric id. Raises
    KeyError when the kernel reported no family called *family_name*.
    """
    def __init__(self, family_name):
        self.family_name = family_name

        if genl_family_name_to_id is None:
            _genl_load_families()

        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']
355        self.family_id = genl_family_name_to_id[family_name]['id']
356
357
358#
359# YNL implementation details.
360#
361
362
class YnlFamily(SpecFamily):
    """Generic netlink client for one family, driven by its YAML spec.

    On construction it parses the spec (via SpecFamily), opens a
    NETLINK_GENERIC socket with NETLINK_CAP_ACK and NETLINK_EXT_ACK enabled,
    binds one method per spec operation (named after the operation's
    ``ident_name``) and resolves the family id from the kernel.

    Set ``include_raw`` to True to also keep the raw NlMsg / GenlMsg objects
    in queued notifications.
    """
    def __init__(self, def_path, schema=None):
        super().__init__(def_path, schema)

        self.include_raw = False

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)

        # genl_cmd values of async (notification) messages, and decoded
        # notifications waiting to be consumed.
        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        # Expose each operation as self.<ident_name>(vals=None, dump=False).
        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

        try:
            self.family = GenlFamily(self.yaml['name'])
        except KeyError:
            # The object is unusable; don't leak the socket opened above.
            self.sock.close()
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

    def ntf_subscribe(self, mcast_name):
        """Join the family's multicast group *mcast_name* on this socket.

        Raises Exception if the kernel reported no such group (including
        families that reported no multicast groups at all).
        """
        groups = self.family.genl_family.get('mcast', {})
        if mcast_name not in groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')

        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             groups[mcast_name])

    def _add_attr(self, space, name, value):
        """Encode attribute *name* of attribute set *space* as a padded TLV.

        Nests are encoded recursively from a dict, flags carry no payload,
        strings become NUL-terminated ASCII, binary takes a hex string or a
        bytes-like value, and scalars are packed in the spec's byte order.
        """
        attr = self.attr_sets[space][name]
        nl_type = attr.value
        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
        elif attr["type"] == 'flag':
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            if isinstance(value, (bytes, bytearray)):
                attr_payload = bytes(value)
            else:
                attr_payload = bytes.fromhex(value)
        elif attr['type'] in NlAttr.type_formats:
            format = NlAttr.get_format(attr['type'], attr.byte_order)
            attr_payload = format.pack(int(value))
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        # Pad to a 4-byte boundary; the length field excludes the padding.
        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Map integer *raw* to enum entry name(s) of ``attr_spec['enum']``.

        With 'enum-as-flags', each set bit i becomes the name of the entry
        whose value is i, and a set of names is returned.
        """
        enum = self.consts[attr_spec['enum']]
        if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attribute as a struct, a C array or raw bytes."""
        if attr_spec.struct_name:
            members = self.consts[attr_spec.struct_name]
            decoded = attr.as_struct(members)
            for m in members:
                if m.enum:
                    decoded[m.name] = self._decode_enum(decoded[m.name], m)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = NlAttr.formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode(self, attrs, space):
        """Decode *attrs* against attribute set *space* into a dict.

        Nests recurse; multi-attr attributes collect into lists.
        """
        attr_space = self.attr_sets[space]
        rsp = dict()
        for attr in attrs:
            attr_spec = attr_space.attrs_by_val[attr.type]
            if attr_spec["type"] == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
                decoded = subdict
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                decoded = True
            elif attr_spec["type"] in NlAttr.type_formats:
                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
            else:
                raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')

            if 'enum' in attr_spec:
                decoded = self._decode_enum(decoded, attr_spec)

            if not attr_spec.is_multi:
                rsp[attr_spec['name']] = decoded
            elif attr_spec.name in rsp:
                rsp[attr_spec.name].append(decoded)
            else:
                rsp[attr_spec.name] = [decoded]

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return the dotted attribute path at byte *target* of the request.

        *offset* is the request offset of the first attribute in *attrs*.
        Returns None when no attribute starts exactly at *target*.
        """
        for attr in attrs:
            attr_spec = attr_set.attrs_by_val[attr.type]
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Target lies inside this nest: skip its header and recurse.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, attr_space, extack, fixed_header_members=None):
        """Replace extack 'bad-attr-offs' by a 'bad-attr' path, if resolvable.

        *request* is the encoded request; *fixed_header_members* must match
        the fixed header it was built with, so that the header is neither
        parsed as attributes nor counted in the attribute offsets.
        """
        if 'bad-attr-offs' not in extack:
            return

        members = fixed_header_members or []
        genl_req = GenlMsg(NlMsg(request, 0, attr_space=attr_space), members)
        # Attributes start after nlmsghdr (16) + genlmsghdr (4) + fixed header.
        attrs_off = 20 + sum(NlAttr.get_format(m.type, m.byte_order).size for m in members)
        path = self._decode_extack_path(genl_req.raw_attrs, attr_space,
                                        attrs_off, extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def handle_ntf(self, nl_msg, genl_msg):
        """Decode a notification and append it to async_msg_queue."""
        msg = dict()
        if self.include_raw:
            msg['nlmsg'] = nl_msg
            msg['genlmsg'] = genl_msg
        op = self.rsp_by_value[genl_msg.genl_cmd]
        msg['name'] = op['name']
        msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Drain pending messages without blocking, queueing notifications.

        Errors, DONE messages and unexpected commands are printed and skipped.
        """
        while True:
            try:
                reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                gm = GenlMsg(nl_msg)
                if gm.genl_cmd not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", gm)
                    continue

                self.handle_ntf(nl_msg, gm)

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _op(self, method, vals=None, dump=False):
        """Send operation *method* with attributes *vals* and collect replies.

        *vals* maps attribute and fixed-header member names to values; it is
        not modified, and None means no attributes. With *dump* the request
        carries NLM_F_DUMP. Returns None when no matching reply arrived, the
        single decoded dict for a non-dump request with one reply, and a list
        of dicts otherwise. Notifications seen meanwhile are queued.

        Raises:
            NlError: if the kernel answers with an error.
        """
        op = self.ops[method]
        # Copy so consuming fixed-header members below doesn't mutate the
        # caller's dict (reusing it would otherwise send zeroed headers).
        vals = dict(vals) if vals else {}

        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        if dump:
            nl_flags |= Netlink.NLM_F_DUMP

        req_seq = random.randint(1024, 65535)
        msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
        fixed_header_members = []
        if op.fixed_header:
            fixed_header_members = self.consts[op.fixed_header].members
            for m in fixed_header_members:
                # Members the caller didn't supply are sent as 0.
                value = vals.pop(m.name, 0)
                format = NlAttr.get_format(m.type, m.byte_order)
                msg += format.pack(value)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value)
        msg = _genl_msg_finalize(msg)

        self.sock.send(msg, 0)

        done = False
        rsp = []
        while not done:
            reply = self.sock.recv(128 * 1024)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            for nl_msg in nms:
                if nl_msg.extack:
                    self._decode_extack(msg, op.attr_set, nl_msg.extack,
                                        fixed_header_members)

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)
                    done = True
                    break

                gm = GenlMsg(nl_msg, fixed_header_members)
                # Check if this is a reply to our request
                if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
                    if gm.genl_cmd in self.async_msg_ids:
                        self.handle_ntf(nl_msg, gm)
                        continue
                    else:
                        print('Unexpected message: ' + repr(gm))
                        continue

                rsp_msg = self._decode(gm.raw_attrs, op.attr_set.name)
                rsp_msg.update(gm.fixed_header_attrs)
                rsp.append(rsp_msg)

        if not rsp:
            return None
        if not dump and len(rsp) == 1:
            return rsp[0]
        return rsp

    def do(self, method, vals=None):
        """Run *method* as a 'do' request; see _op() for the return value."""
        return self._op(method, vals)

    def dump(self, method, vals=None):
        """Run *method* as a dump request; see _op() for the return value."""
        return self._op(method, vals, dump=True)
623        return self._op(method, vals, dump=True)
624