xref: /linux/tools/net/ynl/lib/ynl.py (revision 52990390f91c1c39ca742fc8f390b29891d95127)
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3import functools
4import os
5import random
6import socket
7import struct
8import yaml
9
10from .nlspec import SpecFamily
11
12#
13# Generic Netlink code which should really be in some library, but I can't quickly find one.
14#
15
16
class Netlink:
    """Namespace for the numeric netlink / generic netlink constants used in this file."""

    # Socket level and socket option numbers (handed to setsockopt())
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11

    # Netlink message types
    NLMSG_ERROR = 0x2
    NLMSG_DONE = 0x3

    # Netlink message header flags
    NLM_F_REQUEST = 0x001
    NLM_F_ACK = 0x004
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_APPEND = 0x800

    # These two share their bit values with NLM_F_ROOT / NLM_F_MATCH above
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # Flag bits stored in the type field of an attribute header
    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    # Both flag bits together; the bare attribute type is `type & ~NLA_TYPE_MASK`
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Generic netlink: protocol number and the id of the controller family
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl command
    CTRL_CMD_GETFAMILY = 3

    # nlctrl family attributes
    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    # Attributes of one entry nested in CTRL_ATTR_MCAST_GROUPS
    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack attribute types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6
68
69
class NlError(Exception):
    """Exception carrying a netlink message whose error code is non-zero.

    Attributes:
        nl_msg: the message the error was built from. Its ``error`` attribute
            is negated and handed to os.strerror() when formatting.
    """

    def __init__(self, nl_msg):
        # Initialise the Exception base as well, so that ``args`` (and with it
        # the default repr() and pickling support) reflects the message.
        super().__init__(nl_msg)
        self.nl_msg = nl_msg

    def __str__(self):
        return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}"
76
77
class NlAttr:
    """One netlink attribute (4-byte header plus payload) read from a byte buffer.

    Attributes:
        type: attribute type with the NLA_F_* flag bits cleared.
        payload_len: the length field of the header as received. Despite the
            name it counts the 4 header bytes too (see how ``raw`` is sliced).
        full_len: ``payload_len`` rounded up to a multiple of 4, i.e. the
            distance to the next attribute in the buffer.
        raw: the payload bytes, without header and padding.
    """

    # scalar type name -> (struct format character, size in bytes)
    type_formats = { 'u8' : ('B', 1), 's8' : ('b', 1),
                     'u16': ('H', 2), 's16': ('h', 2),
                     'u32': ('I', 4), 's32': ('i', 4),
                     'u64': ('Q', 8), 's64': ('q', 8) }

    def __init__(self, raw, offset):
        """Parse the attribute whose header starts at ``offset`` in ``raw``."""
        self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4:offset + self.payload_len]

    @staticmethod
    def format_byte_order(byte_order):
        """Translate a byte order name into a struct format prefix.

        Returns ">" for "big-endian", "<" for any other truthy value and ""
        (native order) when ``byte_order`` is None or empty.
        """
        if byte_order:
            return ">" if byte_order == "big-endian" else "<"
        return ""

    def as_u8(self):
        """Return the payload decoded as an unsigned 8-bit integer."""
        return struct.unpack("B", self.raw)[0]

    def as_u16(self, byte_order=None):
        """Return the payload decoded as an unsigned 16-bit integer."""
        endian = NlAttr.format_byte_order(byte_order)
        return struct.unpack(f"{endian}H", self.raw)[0]

    def as_u32(self, byte_order=None):
        """Return the payload decoded as an unsigned 32-bit integer."""
        endian = NlAttr.format_byte_order(byte_order)
        return struct.unpack(f"{endian}I", self.raw)[0]

    def as_u64(self, byte_order=None):
        """Return the payload decoded as an unsigned 64-bit integer."""
        endian = NlAttr.format_byte_order(byte_order)
        return struct.unpack(f"{endian}Q", self.raw)[0]

    def as_strz(self):
        """Return the payload as an ASCII string with its last character dropped.

        The dropped character is presumably the terminating NUL byte.
        """
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes unchanged."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of scalars.

        Args:
            type: element type name, a key of ``type_formats``.

        Returns:
            The elements as a list, in payload order and including repeated
            values (an array is not a set).
        """
        fmt, _ = self.type_formats[type]
        return [item[0] for item in struct.iter_unpack(fmt, self.raw)]

    def as_struct(self, members):
        """Decode the payload as consecutive scalar members.

        Args:
            members: iterable of objects with a ``type`` (key of
                ``type_formats``) and a ``name`` attribute.

        Returns:
            dict mapping each member name to its decoded value.
        """
        value = dict()
        offset = 0
        for m in members:
            # TODO: handle non-scalar members
            fmt, size = self.type_formats[m.type]
            decoded = struct.unpack_from(fmt, self.raw, offset)
            offset += size
            value[m.name] = decoded[0]
        return value

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
134
135
class NlAttrs:
    """The sequence of attributes packed back to back in a byte buffer.

    Iterating the object yields the parsed NlAttr instances in buffer order.
    """

    def __init__(self, msg):
        """Parse every attribute found in ``msg``.

        Raises:
            ValueError: an attribute announces a length smaller than its own
                4-byte header.
        """
        self.attrs = []

        offset = 0
        while offset < len(msg):
            attr = NlAttr(msg, offset)
            # The length field covers the 4-byte header, so anything smaller
            # is malformed. A length of 0 in particular would leave offset
            # where it is and make this loop spin forever.
            if attr.payload_len < 4:
                raise ValueError(f"Invalid attribute length {attr.payload_len} at offset {offset}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)
156
157
class NlMsg:
    """One netlink message: the 16-byte header plus its payload.

    Attributes:
        hdr: the 16 raw header bytes.
        nl_len, nl_type, nl_flags, nl_seq, nl_portid: the header fields.
        raw: the payload, i.e. the bytes after the header up to nl_len.
        error: the signed error code of an NLMSG_ERROR message, otherwise 0.
        done: 1 for NLMSG_ERROR and NLMSG_DONE messages, otherwise 0.
        extack: dict of decoded extended ack attributes, or None when the
            message carries none.
    """

    def __init__(self, msg, offset, attr_space=None):
        """Parse the message which starts at ``offset`` in ``msg``.

        When ``attr_space`` is given it is used to replace a top-level
        'miss-type' extack value with the attribute's name (and doc).
        """
        self.hdr = msg[offset:offset + 16]

        (self.nl_len, self.nl_type, self.nl_flags,
         self.nl_seq, self.nl_portid) = struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16:offset + self.nl_len]

        self.error = 0
        self.done = 0

        # Where the extack attributes begin inside the payload. 20 for errors
        # is the 4-byte error code plus, presumably, the 16-byte header of the
        # request echoed back by the kernel.
        tlv_start = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            tlv_start = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.done = 1
            tlv_start = 4

        self.extack = None
        if not (self.nl_flags & Netlink.NLM_F_ACK_TLVS) or not tlv_start:
            return

        self.extack = dict()
        # extack attributes which are all decoded as u32 -> key in self.extack
        u32_keys = {
            Netlink.NLMSGERR_ATTR_MISS_TYPE: 'miss-type',
            Netlink.NLMSGERR_ATTR_MISS_NEST: 'miss-nest',
            Netlink.NLMSGERR_ATTR_OFFS: 'bad-attr-offs',
        }
        for tlv in NlAttrs(self.raw[tlv_start:]):
            if tlv.type == Netlink.NLMSGERR_ATTR_MSG:
                self.extack['msg'] = tlv.as_strz()
            elif tlv.type in u32_keys:
                self.extack[u32_keys[tlv.type]] = tlv.as_u32()
            else:
                self.extack.setdefault('unknown', []).append(tlv)

        if not attr_space:
            return
        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' not in self.extack or 'miss-nest' in self.extack:
            return
        missing = self.extack['miss-type']
        if missing not in attr_space.attrs_by_val:
            return
        spec = attr_space.attrs_by_val[missing]
        desc = spec['name']
        if 'doc' in spec:
            desc += f" ({spec['doc']})"
        self.extack['miss-type'] = desc

    def __repr__(self):
        parts = [f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"]
        if self.error:
            parts.append('\terror: ' + str(self.error))
        if self.extack:
            parts.append('\textack: ' + repr(self.extack))
        return ''.join(parts)
215
216
class NlMsgs:
    """All netlink messages contained in one receive buffer.

    Iterating the object yields the parsed NlMsg instances in buffer order.
    """

    def __init__(self, data, attr_space=None):
        """Split ``data`` into messages; ``attr_space`` is passed on to every NlMsg.

        Raises:
            ValueError: a message announces a length smaller than the 16-byte
                netlink header.
        """
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            # A length below the header size is malformed, and a length of 0
            # would never advance offset, i.e. loop forever.
            if msg.nl_len < 16:
                raise ValueError(f"Invalid netlink message length {msg.nl_len} at offset {offset}")
            # The next message starts at the following 4-byte boundary
            # (what NLMSG_NEXT() does with NLMSG_ALIGN() in C).
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
229
230
# Module-wide cache of the generic netlink families reported by the kernel:
# family name -> dict with 'id' and 'name' plus, when reported, 'maxattr' and
# 'mcast' (multicast group name -> group id). It stays None until
# _genl_load_families() has been run.
genl_family_name_to_id = None
232
233
234def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
235    # we prepend length in _genl_msg_finalize()
236    if seq is None:
237        seq = random.randint(1, 1024)
238    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
239    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
240    return nlmsg + genlmsg
241
242
243def _genl_msg_finalize(msg):
244    return struct.pack("I", len(msg) + 4) + msg
245
246
def _genl_parse_mcast_groups(groups_attr):
    """Decode a CTRL_ATTR_MCAST_GROUPS attribute into a {group name: group id} dict."""
    groups = dict()
    for entry in NlAttrs(groups_attr.raw):
        grp_name = None
        grp_id = None
        for field in NlAttrs(entry.raw):
            if field.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                grp_name = field.as_strz()
            elif field.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                grp_id = field.as_u32()
        # Entries without a (non-empty) name or without an id are skipped
        if grp_name and grp_id is not None:
            groups[grp_name] = grp_id
    return groups


def _genl_load_families():
    """Dump the generic netlink families from the kernel into genl_family_name_to_id.

    Sends a CTRL_CMD_GETFAMILY dump request on a temporary socket and records
    every family that reports both a name and an id. A netlink error is only
    printed; the families collected up to that point are kept.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP
        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL, flags, Netlink.CTRL_CMD_GETFAMILY, 1))

        sock.send(request, 0)

        genl_family_name_to_id = dict()

        # Keep reading until the kernel signals the end of the dump (or an error)
        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                family = dict()
                for attr in GenlMsg(nl_msg).raw_attrs:
                    kind = attr.type
                    if kind == Netlink.CTRL_ATTR_FAMILY_ID:
                        family['id'] = attr.as_u16()
                    elif kind == Netlink.CTRL_ATTR_FAMILY_NAME:
                        family['name'] = attr.as_strz()
                    elif kind == Netlink.CTRL_ATTR_MAXATTR:
                        family['maxattr'] = attr.as_u32()
                    elif kind == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        family['mcast'] = _genl_parse_mcast_groups(attr)
                if 'name' in family and 'id' in family:
                    genl_family_name_to_id[family['name']] = family
294
295
class GenlMsg:
    """Generic netlink view of a netlink message.

    Attributes:
        nl: the wrapped netlink message.
        hdr: the 4 raw bytes of the generic netlink header.
        genl_cmd, genl_version: fields of that header.
        fixed_header_attrs: dict of the decoded fixed header members.
        raw: the bytes following the generic netlink and fixed headers.
        raw_attrs: ``raw`` parsed into attributes (NlAttrs).
    """

    def __init__(self, nl_msg, fixed_header_members=None):
        """Parse the payload of ``nl_msg``.

        Args:
            nl_msg: message object whose ``raw`` attribute holds the payload.
            fixed_header_members: optional iterable of objects with ``type``
                (a key of NlAttr.type_formats) and ``name``, describing the
                scalars located between the generic netlink header and the
                attributes. None (the default) means there are none.
        """
        # None instead of a shared mutable [] default
        if fixed_header_members is None:
            fixed_header_members = ()

        self.nl = nl_msg

        self.hdr = nl_msg.raw[0:4]
        offset = 4

        self.genl_cmd, self.genl_version, _ = struct.unpack("BBH", self.hdr)

        self.fixed_header_attrs = dict()
        for m in fixed_header_members:
            fmt, size = NlAttr.type_formats[m.type]
            decoded = struct.unpack_from(fmt, nl_msg.raw, offset)
            offset += size
            self.fixed_header_attrs[m.name] = decoded[0]

        self.raw = nl_msg.raw[offset:]
        self.raw_attrs = NlAttrs(self.raw)

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        for a in self.raw_attrs:
            msg += '\t\t' + repr(a) + '\n'
        return msg
321
322
class GenlFamily:
    """Kernel-side information about one generic netlink family.

    Attributes:
        family_name: the name the family was looked up by.
        genl_family: the family's entry in genl_family_name_to_id.
        family_id: the 'id' value of that entry.
    """

    def __init__(self, family_name):
        """Look up ``family_name``; a KeyError escapes when it is not in the cache."""
        self.family_name = family_name

        # The family list is fetched from the kernel once, on first use
        if genl_family_name_to_id is None:
            _genl_load_families()

        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']
333
334
335#
336# YNL implementation details.
337#
338
339
340class YnlFamily(SpecFamily):
341    def __init__(self, def_path, schema=None):
342        super().__init__(def_path, schema)
343
344        self.include_raw = False
345
346        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
347        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
348        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
349
350        self.async_msg_ids = set()
351        self.async_msg_queue = []
352
353        for msg in self.msgs.values():
354            if msg.is_async:
355                self.async_msg_ids.add(msg.rsp_value)
356
357        for op_name, op in self.ops.items():
358            bound_f = functools.partial(self._op, op_name)
359            setattr(self, op.ident_name, bound_f)
360
361        try:
362            self.family = GenlFamily(self.yaml['name'])
363        except KeyError:
364            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
365
366    def ntf_subscribe(self, mcast_name):
367        if mcast_name not in self.family.genl_family['mcast']:
368            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
369
370        self.sock.bind((0, 0))
371        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
372                             self.family.genl_family['mcast'][mcast_name])
373
374    def _add_attr(self, space, name, value):
375        attr = self.attr_sets[space][name]
376        nl_type = attr.value
377        if attr["type"] == 'nest':
378            nl_type |= Netlink.NLA_F_NESTED
379            attr_payload = b''
380            for subname, subvalue in value.items():
381                attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
382        elif attr["type"] == 'flag':
383            attr_payload = b''
384        elif attr["type"] == 'u8':
385            attr_payload = struct.pack("B", int(value))
386        elif attr["type"] == 'u16':
387            endian = NlAttr.format_byte_order(attr.byte_order)
388            attr_payload = struct.pack(f"{endian}H", int(value))
389        elif attr["type"] == 'u32':
390            endian = NlAttr.format_byte_order(attr.byte_order)
391            attr_payload = struct.pack(f"{endian}I", int(value))
392        elif attr["type"] == 'u64':
393            endian = NlAttr.format_byte_order(attr.byte_order)
394            attr_payload = struct.pack(f"{endian}Q", int(value))
395        elif attr["type"] == 'string':
396            attr_payload = str(value).encode('ascii') + b'\x00'
397        elif attr["type"] == 'binary':
398            attr_payload = value
399        else:
400            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
401
402        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
403        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
404
405    def _decode_enum(self, rsp, attr_spec):
406        raw = rsp[attr_spec['name']]
407        enum = self.consts[attr_spec['enum']]
408        i = attr_spec.get('value-start', 0)
409        if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
410            value = set()
411            while raw:
412                if raw & 1:
413                    value.add(enum.entries_by_val[i].name)
414                raw >>= 1
415                i += 1
416        else:
417            value = enum.entries_by_val[raw - i].name
418        rsp[attr_spec['name']] = value
419
420    def _decode_binary(self, attr, attr_spec):
421        if attr_spec.struct_name:
422            decoded = attr.as_struct(self.consts[attr_spec.struct_name])
423        elif attr_spec.sub_type:
424            decoded = attr.as_c_array(attr_spec.sub_type)
425        else:
426            decoded = attr.as_bin()
427        return decoded
428
429    def _decode(self, attrs, space):
430        attr_space = self.attr_sets[space]
431        rsp = dict()
432        for attr in attrs:
433            attr_spec = attr_space.attrs_by_val[attr.type]
434            if attr_spec["type"] == 'nest':
435                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
436                decoded = subdict
437            elif attr_spec['type'] == 'u8':
438                decoded = attr.as_u8()
439            elif attr_spec['type'] == 'u16':
440                decoded = attr.as_u16(attr_spec.byte_order)
441            elif attr_spec['type'] == 'u32':
442                decoded = attr.as_u32(attr_spec.byte_order)
443            elif attr_spec['type'] == 'u64':
444                decoded = attr.as_u64(attr_spec.byte_order)
445            elif attr_spec["type"] == 'string':
446                decoded = attr.as_strz()
447            elif attr_spec["type"] == 'binary':
448                decoded = self._decode_binary(attr, attr_spec)
449            elif attr_spec["type"] == 'flag':
450                decoded = True
451            else:
452                raise Exception(f'Unknown {attr.type} {attr_spec["name"]} {attr_spec["type"]}')
453
454            if not attr_spec.is_multi:
455                rsp[attr_spec['name']] = decoded
456            elif attr_spec.name in rsp:
457                rsp[attr_spec.name].append(decoded)
458            else:
459                rsp[attr_spec.name] = [decoded]
460
461            if 'enum' in attr_spec:
462                self._decode_enum(rsp, attr_spec)
463        return rsp
464
465    def _decode_extack_path(self, attrs, attr_set, offset, target):
466        for attr in attrs:
467            attr_spec = attr_set.attrs_by_val[attr.type]
468            if offset > target:
469                break
470            if offset == target:
471                return '.' + attr_spec.name
472
473            if offset + attr.full_len <= target:
474                offset += attr.full_len
475                continue
476            if attr_spec['type'] != 'nest':
477                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
478            offset += 4
479            subpath = self._decode_extack_path(NlAttrs(attr.raw),
480                                               self.attr_sets[attr_spec['nested-attributes']],
481                                               offset, target)
482            if subpath is None:
483                return None
484            return '.' + attr_spec.name + subpath
485
486        return None
487
488    def _decode_extack(self, request, attr_space, extack):
489        if 'bad-attr-offs' not in extack:
490            return
491
492        genl_req = GenlMsg(NlMsg(request, 0, attr_space=attr_space))
493        path = self._decode_extack_path(genl_req.raw_attrs, attr_space,
494                                        20, extack['bad-attr-offs'])
495        if path:
496            del extack['bad-attr-offs']
497            extack['bad-attr'] = path
498
499    def handle_ntf(self, nl_msg, genl_msg):
500        msg = dict()
501        if self.include_raw:
502            msg['nlmsg'] = nl_msg
503            msg['genlmsg'] = genl_msg
504        op = self.rsp_by_value[genl_msg.genl_cmd]
505        msg['name'] = op['name']
506        msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
507        self.async_msg_queue.append(msg)
508
509    def check_ntf(self):
510        while True:
511            try:
512                reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
513            except BlockingIOError:
514                return
515
516            nms = NlMsgs(reply)
517            for nl_msg in nms:
518                if nl_msg.error:
519                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
520                    print(nl_msg)
521                    continue
522                if nl_msg.done:
523                    print("Netlink done while checking for ntf!?")
524                    continue
525
526                gm = GenlMsg(nl_msg)
527                if gm.genl_cmd not in self.async_msg_ids:
528                    print("Unexpected msg id done while checking for ntf", gm)
529                    continue
530
531                self.handle_ntf(nl_msg, gm)
532
533    def operation_do_attributes(self, name):
534      """
535      For a given operation name, find and return a supported
536      set of attributes (as a dict).
537      """
538      op = self.find_operation(name)
539      if not op:
540        return None
541
542      return op['do']['request']['attributes'].copy()
543
544    def _op(self, method, vals, dump=False):
545        op = self.ops[method]
546
547        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
548        if dump:
549            nl_flags |= Netlink.NLM_F_DUMP
550
551        req_seq = random.randint(1024, 65535)
552        msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
553        fixed_header_members = []
554        if op.fixed_header:
555            fixed_header_members = self.consts[op.fixed_header].members
556            for m in fixed_header_members:
557                value = vals.pop(m.name)
558                format, _ = NlAttr.type_formats[m.type]
559                msg += struct.pack(format, value)
560        for name, value in vals.items():
561            msg += self._add_attr(op.attr_set.name, name, value)
562        msg = _genl_msg_finalize(msg)
563
564        self.sock.send(msg, 0)
565
566        done = False
567        rsp = []
568        while not done:
569            reply = self.sock.recv(128 * 1024)
570            nms = NlMsgs(reply, attr_space=op.attr_set)
571            for nl_msg in nms:
572                if nl_msg.extack:
573                    self._decode_extack(msg, op.attr_set, nl_msg.extack)
574
575                if nl_msg.error:
576                    raise NlError(nl_msg)
577                if nl_msg.done:
578                    if nl_msg.extack:
579                        print("Netlink warning:")
580                        print(nl_msg)
581                    done = True
582                    break
583
584                gm = GenlMsg(nl_msg, fixed_header_members)
585                # Check if this is a reply to our request
586                if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
587                    if gm.genl_cmd in self.async_msg_ids:
588                        self.handle_ntf(nl_msg, gm)
589                        continue
590                    else:
591                        print('Unexpected message: ' + repr(gm))
592                        continue
593
594                rsp.append(self._decode(gm.raw_attrs, op.attr_set.name)
595                           | gm.fixed_header_attrs)
596
597        if not rsp:
598            return None
599        if not dump and len(rsp) == 1:
600            return rsp[0]
601        return rsp
602
603    def do(self, method, vals):
604        return self._op(method, vals)
605
606    def dump(self, method, vals):
607        return self._op(method, vals, dump=True)
608