# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
import functools
import os
import random
import socket
import struct
from struct import Struct
import yaml
import ipaddress
import uuid

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Constants from the netlink / generic netlink / nlctrl UAPI headers."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_APPEND = 0x800

    # Flags carried on NLMSG_ERROR / NLMSG_DONE replies (same bit values as
    # ROOT/MATCH above, but with a different meaning on ack messages).
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6


class NlError(Exception):
    """Raised when the kernel answers a request with a non-zero NLMSG_ERROR."""

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg

    def __str__(self):
        return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}"


class NlAttr:
    """A single netlink attribute (TLV) parsed from ``raw`` at ``offset``."""

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
        # nla_len covers the 4-byte header, so anything smaller is malformed.
        # A zero length would also make NlAttrs loop forever (full_len == 0).
        if self._len < 4:
            raise Exception(f"Malformed netlink attribute: length {self._len} at offset {offset}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.payload_len = self._len
        # Attributes are padded to a 4-byte boundary.
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4:offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for ``attr_type``.

        Native order when ``byte_order`` is falsy, big-endian for
        "big-endian", little-endian for any other value.
        """
        format = cls.type_formats[attr_type]
        if byte_order:
            return format.big if byte_order == "big-endian" \
                else format.little
        return format.native

    @classmethod
    def formatted_string(cls, raw, display_hint):
        """Render ``raw`` per ``display_hint``; unknown hints return ``raw`` unchanged."""
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def as_scalar(self, attr_type, byte_order=None):
        format = self.get_format(attr_type, byte_order)
        return format.unpack(self.raw)[0]

    def as_strz(self):
        # Drops the last character, expected to be the NUL terminator.
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        return self.raw

    def as_c_array(self, type):
        format = self.get_format(type)
        return [ x[0] for x in format.iter_unpack(self.raw) ]

    def as_struct(self, members):
        """Decode the payload as a packed C struct described by ``members``.

        Raises:
            Exception: if a member has a type this decoder cannot handle.
        """
        value = dict()
        offset = 0
        for m in members:
            # TODO: handle non-scalar members
            if m.type == 'binary':
                decoded = self.raw[offset:offset + m['len']]
                offset += m['len']
            elif m.type in NlAttr.type_formats:
                format = self.get_format(m.type, m.byte_order)
                [ decoded ] = format.unpack_from(self.raw, offset)
                offset += format.size
            else:
                # Otherwise 'decoded' would be unbound or stale from the
                # previous member, and every later offset would be wrong.
                raise Exception(f"Unsupported struct member type '{m.type}' for member '{m.name}'")
            if m.display_hint:
                decoded = self.formatted_string(decoded, m.display_hint)
            value[m.name] = decoded
        return value

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """Sequence of NlAttr parsed back-to-back from a byte buffer."""

    def __init__(self, msg):
        self.attrs = []

        offset = 0
        while offset < len(msg):
            attr = NlAttr(msg, offset)
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        msg = ''
        for a in self.attrs:
            if msg:
                msg += '\n'
            msg += repr(a)
        return msg


class NlMsg:
    """One netlink message: 16-byte nlmsghdr plus payload, with extack decoding.

    ``error`` and ``done`` are set for NLMSG_ERROR / NLMSG_DONE messages;
    ``extack`` is a dict when the kernel attached ack TLVs, else None.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset:offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)
        # nlmsg_len includes the header; smaller values are malformed and a
        # zero length would make NlMsgs loop forever.
        if self.nl_len < 16:
            raise Exception(f"Malformed netlink message: length {self.nl_len} at offset {offset}")

        self.raw = msg[offset + 16:offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # Skip the error code (4) and the echoed request nlmsghdr (16).
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        desc = spec['name']
                        if 'doc' in spec:
                            desc += f" ({spec['doc']})"
                        self.extack['miss-type'] = desc

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
        if self.error:
            msg += '\terror: ' + str(self.error)
        if self.extack:
            msg += '\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages contained in one recv() buffer."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            # Consecutive messages start on 4-byte boundaries (NLMSG_ALIGN).
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Cache filled by _genl_load_families(): family name -> dict with 'id',
# 'name' and optionally 'maxattr' / 'mcast'.
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build nlmsghdr (minus length) + genlmsghdr; a random seq is used if none given."""
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the 4-byte nlmsg_len (which counts itself) to ``msg``."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump all generic netlink families into ``genl_family_name_to_id``."""
    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        global genl_family_name_to_id
        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in gm.raw_attrs:
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """Generic netlink view of an NlMsg: genlmsghdr, optional fixed header, attrs.

    Args:
        nl_msg: the NlMsg to interpret.
        fixed_header_members: iterable of scalar members (with ``type``,
            ``byte_order`` and ``name``) laid out right after genlmsghdr.
    """

    def __init__(self, nl_msg, fixed_header_members=None):
        self.nl = nl_msg

        self.hdr = nl_msg.raw[0:4]
        offset = 4

        self.genl_cmd, self.genl_version, _ = struct.unpack("BBH", self.hdr)

        self.fixed_header_attrs = dict()
        for m in fixed_header_members or ():
            format = NlAttr.get_format(m.type, m.byte_order)
            decoded = format.unpack_from(nl_msg.raw, offset)
            offset += format.size
            self.fixed_header_attrs[m.name] = decoded[0]

        self.raw = nl_msg.raw[offset:]
        self.raw_attrs = NlAttrs(self.raw)

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        for a in self.raw_attrs:
            msg += '\t\t' + repr(a) + '\n'
        return msg


class GenlFamily:
    """Kernel-side info for one generic netlink family.

    Raises:
        KeyError: if the kernel does not know ``family_name``.
    """

    def __init__(self, family_name):
        self.family_name = family_name

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """Generic netlink client driven by a YNL spec.

    Every op in the spec is exposed as a method named ``op.ident_name``
    taking ``(vals, dump=False)``.
    """

    def __init__(self, def_path, schema=None):
        super().__init__(def_path, schema)

        self.include_raw = False

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)

        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

        try:
            self.family = GenlFamily(self.yaml['name'])
        except KeyError:
            # Don't leak the socket opened above when construction fails.
            self.sock.close()
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel") from None

    def ntf_subscribe(self, mcast_name):
        """Join the family's multicast group ``mcast_name``."""
        if mcast_name not in self.family.genl_family['mcast']:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')

        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             self.family.genl_family['mcast'][mcast_name])

    def _add_attr(self, space, name, value):
        """Encode attribute ``name`` of attr-set ``space`` as padded TLV bytes."""
        attr = self.attr_sets[space][name]
        nl_type = attr.value
        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
        elif attr["type"] == 'flag':
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            attr_payload = bytes.fromhex(value)
        elif attr['type'] in NlAttr.type_formats:
            format = NlAttr.get_format(attr['type'], attr.byte_order)
            attr_payload = format.pack(int(value))
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum_value(self, raw, attr_spec):
        """Map one raw integer to its enum name (or a set of names for flags)."""
        enum = self.consts[attr_spec['enum']]
        i = attr_spec.get('value-start', 0)
        if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
            return value
        return enum.entries_by_val[raw - i].name

    def _decode_enum(self, rsp, attr_spec):
        """Replace ``rsp[attr_spec['name']]`` in place with its enum decoding."""
        name = attr_spec['name']
        rsp[name] = self._decode_enum_value(rsp[name], attr_spec)

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attr as a struct, a C array or raw bytes per the spec."""
        if attr_spec.struct_name:
            members = self.consts[attr_spec.struct_name]
            decoded = attr.as_struct(members)
            for m in members:
                if m.enum:
                    self._decode_enum(decoded, m)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = NlAttr.formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode(self, attrs, space):
        """Decode ``attrs`` against attr-set ``space`` into a dict keyed by attr name.

        Multi-attrs accumulate into lists. Enum-typed values are decoded per
        value before being stored, so multi-attr enums work too.
        """
        attr_space = self.attr_sets[space]
        rsp = dict()
        for attr in attrs:
            attr_spec = attr_space.attrs_by_val[attr.type]
            if attr_spec["type"] == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
                decoded = subdict
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                decoded = True
            elif attr_spec["type"] in NlAttr.type_formats:
                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
            else:
                raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')

            # Decode the single value here. Decoding rsp[name] after a
            # multi-attr append would hit the whole list.
            if 'enum' in attr_spec:
                decoded = self._decode_enum_value(decoded, attr_spec)

            if not attr_spec.is_multi:
                rsp[attr_spec['name']] = decoded
            elif attr_spec.name in rsp:
                rsp[attr_spec.name].append(decoded)
            else:
                rsp[attr_spec.name] = [decoded]
        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return a dotted attr path ('.a.b') for byte ``target``, or None."""
        for attr in attrs:
            attr_spec = attr_set.attrs_by_val[attr.type]
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, attr_space, extack, fixed_header_members=None):
        """Translate extack 'bad-attr-offs' into a 'bad-attr' path in place.

        Args:
            request: the finalized request bytes that were sent.
            attr_space: attr-set of the request's top-level attributes.
            extack: extack dict from the reply NlMsg (modified in place).
            fixed_header_members: the op's fixed header members, if any, so
                attribute offsets are computed past the fixed header.
        """
        if 'bad-attr-offs' not in extack:
            return

        genl_req = GenlMsg(NlMsg(request, 0, attr_space=attr_space), fixed_header_members)
        # Attributes start after nlmsghdr (16 bytes), genlmsghdr and any
        # fixed header, i.e. wherever GenlMsg stopped consuming nl.raw.
        attrs_off = 16 + len(genl_req.nl.raw) - len(genl_req.raw)
        path = self._decode_extack_path(genl_req.raw_attrs, attr_space,
                                        attrs_off, extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def handle_ntf(self, nl_msg, genl_msg):
        """Decode a notification and append it to ``async_msg_queue``."""
        msg = dict()
        if self.include_raw:
            msg['nlmsg'] = nl_msg
            msg['genlmsg'] = genl_msg
        op = self.rsp_by_value[genl_msg.genl_cmd]
        msg['name'] = op['name']
        msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Drain pending notifications from the socket without blocking."""
        while True:
            try:
                reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                gm = GenlMsg(nl_msg)
                if gm.genl_cmd not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", gm)
                    continue

                self.handle_ntf(nl_msg, gm)

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _op(self, method, vals, dump=False):
        """Send op ``method`` with attribute values ``vals`` and collect replies.

        Returns None when there are no replies. A single reply to a non-dump
        request is returned as a dict; otherwise a list of dicts is returned.

        Raises:
            NlError: if the kernel replies with an error.
        """
        op = self.ops[method]
        # Work on a copy: fixed-header members are popped below and must not
        # disappear from the caller's dict.
        vals = dict(vals) if vals else {}

        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        if dump:
            nl_flags |= Netlink.NLM_F_DUMP

        req_seq = random.randint(1024, 65535)
        msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
        fixed_header_members = []
        if op.fixed_header:
            fixed_header_members = self.consts[op.fixed_header].members
            for m in fixed_header_members:
                value = vals.pop(m.name) if m.name in vals else 0
                format = NlAttr.get_format(m.type, m.byte_order)
                msg += format.pack(value)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value)
        msg = _genl_msg_finalize(msg)

        self.sock.send(msg, 0)

        done = False
        rsp = []
        while not done:
            reply = self.sock.recv(128 * 1024)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            for nl_msg in nms:
                if nl_msg.extack:
                    self._decode_extack(msg, op.attr_set, nl_msg.extack,
                                        fixed_header_members)

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)
                    done = True
                    break

                gm = GenlMsg(nl_msg, fixed_header_members)
                # Check if this is a reply to our request
                if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
                    if gm.genl_cmd in self.async_msg_ids:
                        self.handle_ntf(nl_msg, gm)
                        continue
                    else:
                        print('Unexpected message: ' + repr(gm))
                        continue

                rsp_msg = self._decode(gm.raw_attrs, op.attr_set.name)
                rsp_msg.update(gm.fixed_header_attrs)
                rsp.append(rsp_msg)

        if not rsp:
            return None
        if not dump and len(rsp) == 1:
            return rsp[0]
        return rsp

    def do(self, method, vals):
        return self._op(method, vals)

    def dump(self, method, vals):
        return self._op(method, vals, dump=True)