# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

import functools
import os
import random
import socket
import struct
import yaml

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Netlink and generic netlink numeric constants used by this module."""
    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_APPEND = 0x800

    # Ack-message flags; note they share bit values with NLM_F_ROOT / NLM_F_MATCH.
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6


class NlError(Exception):
    """Exception wrapping a netlink message that carries a non-zero error.

    ``nl_msg.error`` is expected to be a negative errno value; ``__str__``
    renders it with ``os.strerror`` followed by the message itself.
    """
    def __init__(self, nl_msg):
        self.nl_msg = nl_msg

    def __str__(self):
        return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}"


class NlAttr:
    """A single netlink attribute parsed out of a raw buffer.

    ``raw`` holds only the attribute payload: the bytes after the 4-byte
    length/type header, without trailing alignment padding.
    """
    # Spec scalar type name -> (struct format character, size in bytes).
    type_formats = {'u8': ('B', 1), 's8': ('b', 1),
                    'u16': ('H', 2), 's16': ('h', 2),
                    'u32': ('I', 4), 's32': ('i', 4),
                    'u64': ('Q', 8), 's64': ('q', 8)}

    def __init__(self, raw, offset):
        """Parse the attribute that starts at ``raw[offset]``."""
        self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
        # Strip the NESTED / NET_BYTEORDER flag bits to get the attribute id.
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.payload_len = self._len
        # The length is rounded up to a multiple of 4; this is the stride to
        # the next attribute in the buffer.
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4:offset + self.payload_len]

    @staticmethod
    def format_byte_order(byte_order):
        """Return the struct byte-order prefix for a spec byte-order string.

        "big-endian" maps to ">", any other non-empty value to "<", and
        None/empty to "" (struct's default native mode).
        """
        if byte_order:
            return ">" if byte_order == "big-endian" else "<"
        return ""

    def as_u8(self):
        return struct.unpack("B", self.raw)[0]

    def as_u16(self, byte_order=None):
        endian = NlAttr.format_byte_order(byte_order)
        return struct.unpack(f"{endian}H", self.raw)[0]

    def as_u32(self, byte_order=None):
        endian = NlAttr.format_byte_order(byte_order)
        return struct.unpack(f"{endian}I", self.raw)[0]

    def as_u64(self, byte_order=None):
        endian = NlAttr.format_byte_order(byte_order)
        return struct.unpack(f"{endian}Q", self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII and drop its final (NUL) character."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of ``type`` scalars.

        Element order and duplicates are preserved.
        """
        fmt, _ = self.type_formats[type]
        return [item[0] for item in struct.iter_unpack(fmt, self.raw)]

    def as_struct(self, members):
        """Decode the payload as consecutive scalar ``members`` into a dict."""
        value = dict()
        offset = 0
        for m in members:
            # TODO: handle non-scalar members
            fmt, size = self.type_formats[m.type]
            decoded = struct.unpack_from(fmt, self.raw, offset)
            offset += size
            value[m.name] = decoded[0]
        return value

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """Sequence of NlAttr objects parsed back-to-back from a byte buffer."""
    def __init__(self, msg):
        self.attrs = []

        offset = 0
        while offset < len(msg):
            attr = NlAttr(msg, offset)
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """One netlink message (16-byte header + payload) parsed from a buffer.

    For NLMSG_ERROR / NLMSG_DONE messages flagged with NLM_F_ACK_TLVS, the
    extended-ack TLVs are decoded into the ``extack`` dict. When
    ``attr_space`` is given, a top-level missing-attribute type is translated
    into its spec name (plus doc, if any).
    """
    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset:offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16:offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # Skip the 4-byte error code and the 16-byte echoed request
            # header; assumes the echoed payload was capped (NETLINK_CAP_ACK).
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.done = 1
            # Skip the leading 4-byte int of the DONE payload.
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_u32()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_u32()
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_u32()
                else:
                    self.extack.setdefault('unknown', []).append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        desc = spec['name']
                        if 'doc' in spec:
                            desc += f" ({spec['doc']})"
                        self.extack['miss-type'] = desc

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
        if self.error:
            msg += '\terror: ' + str(self.error)
        if self.extack:
            msg += '\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages contained in one received buffer."""
    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            offset += msg.nl_len
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Cache of family name -> family info dict, filled lazily by _genl_load_families().
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build a netlink + genetlink header without the leading length field.

    A random sequence number in [1, 1024] is used when ``seq`` is None.
    """
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the 4-byte total length (including itself) to ``msg``."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump all generic netlink families via nlctrl into the global cache.

    Each family that reports both a name and an id is stored in
    ``genl_family_name_to_id`` under its name, as a dict that may contain
    'id', 'name', 'maxattr' and 'mcast' (multicast group name -> id).
    On a netlink error the error is printed and loading stops early.
    """
    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        global genl_family_name_to_id
        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in gm.raw_attrs:
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_u16()
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_u32()
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_u32()
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """Generic netlink view of an NlMsg: genl header, fixed header, attributes.

    ``fixed_header_members`` lists scalar members (with ``name`` and
    ``type``) packed right after the 4-byte genl header; they are decoded
    into ``fixed_header_attrs`` and the remaining bytes into ``raw_attrs``.
    """
    def __init__(self, nl_msg, fixed_header_members=None):
        if fixed_header_members is None:
            fixed_header_members = []
        self.nl = nl_msg

        self.hdr = nl_msg.raw[0:4]
        offset = 4

        self.genl_cmd, self.genl_version, _ = struct.unpack("BBH", self.hdr)

        self.fixed_header_attrs = dict()
        for m in fixed_header_members:
            fmt, size = NlAttr.type_formats[m.type]
            decoded = struct.unpack_from(fmt, nl_msg.raw, offset)
            offset += size
            self.fixed_header_attrs[m.name] = decoded[0]

        self.raw = nl_msg.raw[offset:]
        self.raw_attrs = NlAttrs(self.raw)

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        for a in self.raw_attrs:
            msg += '\t\t' + repr(a) + '\n'
        return msg


class GenlFamily:
    """A generic netlink family looked up by name in the kernel's registry.

    The family table is loaded from the kernel on first use. Raises KeyError
    if ``family_name`` is not present in it.
    """
    def __init__(self, family_name):
        self.family_name = family_name

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """Generic netlink family driven by a YAML spec.

    For every spec operation a bound method named ``op.ident_name`` is set on
    the instance; it forwards to ``_op`` with the operation name filled in.
    """
    def __init__(self, def_path, schema=None):
        super().__init__(def_path, schema)

        self.include_raw = False

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)

        # Response command values of asynchronous (notification) messages.
        self.async_msg_ids = set()
        # Decoded notifications, appended by handle_ntf().
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

        try:
            self.family = GenlFamily(self.yaml['name'])
        except KeyError:
            # Don't leak the socket when the family lookup fails.
            self.sock.close()
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

    def ntf_subscribe(self, mcast_name):
        """Join the family's multicast group ``mcast_name``.

        Raises Exception if the family has no such group.
        """
        if mcast_name not in self.family.genl_family['mcast']:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')

        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             self.family.genl_family['mcast'][mcast_name])

    def _add_attr(self, space, name, value):
        """Encode attribute ``name`` of attribute set ``space`` as netlink TLV bytes.

        Nests are encoded recursively from a dict ``value``. The result is
        padded to a 4-byte boundary. Raises Exception for unsupported types.
        """
        attr = self.attr_sets[space][name]
        nl_type = attr.value
        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
        elif attr["type"] == 'flag':
            attr_payload = b''
        elif attr["type"] == 'u8':
            attr_payload = struct.pack("B", int(value))
        elif attr["type"] == 'u16':
            endian = NlAttr.format_byte_order(attr.byte_order)
            attr_payload = struct.pack(f"{endian}H", int(value))
        elif attr["type"] == 'u32':
            endian = NlAttr.format_byte_order(attr.byte_order)
            attr_payload = struct.pack(f"{endian}I", int(value))
        elif attr["type"] == 'u64':
            endian = NlAttr.format_byte_order(attr.byte_order)
            attr_payload = struct.pack(f"{endian}Q", int(value))
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            attr_payload = value
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, rsp, attr_spec):
        """Replace the raw integer ``rsp[name]`` with its enum entry name(s) in place.

        With 'enum-as-flags', each set bit becomes one entry name in a set.
        """
        raw = rsp[attr_spec['name']]
        enum = self.consts[attr_spec['enum']]
        i = attr_spec.get('value-start', 0)
        if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw - i].name
        rsp[attr_spec['name']] = value

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attribute as a struct, a C array, or raw bytes."""
        if attr_spec.struct_name:
            decoded = attr.as_struct(self.consts[attr_spec.struct_name])
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
        return decoded

    def _decode(self, attrs, space):
        """Decode parsed attributes of attribute set ``space`` into a dict.

        Multi-attributes are collected into lists; nests recurse.
        """
        attr_space = self.attr_sets[space]
        rsp = dict()
        for attr in attrs:
            attr_spec = attr_space.attrs_by_val[attr.type]
            if attr_spec["type"] == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
                decoded = subdict
            elif attr_spec['type'] == 'u8':
                decoded = attr.as_u8()
            elif attr_spec['type'] == 'u16':
                decoded = attr.as_u16(attr_spec.byte_order)
            elif attr_spec['type'] == 'u32':
                decoded = attr.as_u32(attr_spec.byte_order)
            elif attr_spec['type'] == 'u64':
                decoded = attr.as_u64(attr_spec.byte_order)
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                decoded = True
            else:
                raise Exception(f'Unknown {attr.type} {attr_spec["name"]} {attr_spec["type"]}')

            if not attr_spec.is_multi:
                rsp[attr_spec['name']] = decoded
            elif attr_spec.name in rsp:
                rsp[attr_spec.name].append(decoded)
            else:
                rsp[attr_spec.name] = [decoded]

            if 'enum' in attr_spec:
                self._decode_enum(rsp, attr_spec)
        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return the '.a.b' attribute path whose TLV starts at byte ``target``.

        ``offset`` is the message offset of the first attribute in ``attrs``.
        Returns None if no attribute starts exactly at ``target``.
        """
        for attr in attrs:
            attr_spec = attr_set.attrs_by_val[attr.type]
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Skip the nest's own 4-byte TLV header before recursing.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, attr_space, extack, fixed_header_members=None):
        """Translate extack 'bad-attr-offs' into a 'bad-attr' path, in place.

        ``request`` is the finalized request bytes that were sent. The
        reported offset counts from the start of the request's netlink
        header, so attributes begin after the 16-byte netlink header, the
        4-byte genl header and any fixed-header members.
        """
        if 'bad-attr-offs' not in extack:
            return

        if fixed_header_members is None:
            fixed_header_members = []
        genl_req = GenlMsg(NlMsg(request, 0, attr_space=attr_space), fixed_header_members)
        attrs_off = 20 + sum(NlAttr.type_formats[m.type][1] for m in fixed_header_members)
        path = self._decode_extack_path(genl_req.raw_attrs, attr_space,
                                        attrs_off, extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def handle_ntf(self, nl_msg, genl_msg):
        """Decode a notification and append it to ``async_msg_queue``."""
        msg = dict()
        if self.include_raw:
            msg['nlmsg'] = nl_msg
            msg['genlmsg'] = genl_msg
        op = self.rsp_by_value[genl_msg.genl_cmd]
        msg['name'] = op['name']
        msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Drain pending notifications from the socket without blocking."""
        while True:
            try:
                reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                gm = GenlMsg(nl_msg)
                if gm.genl_cmd not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", gm)
                    continue

                self.handle_ntf(nl_msg, gm)

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _op(self, method, vals, dump=False):
        """Send operation ``method`` with attribute values ``vals`` and collect replies.

        ``vals`` is not modified. Returns None if there were no replies, the
        single decoded reply for a non-dump request with one reply, and
        otherwise a list of decoded replies. Raises NlError on a netlink error.
        """
        op = self.ops[method]

        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        if dump:
            nl_flags |= Netlink.NLM_F_DUMP

        req_seq = random.randint(1024, 65535)
        msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
        fixed_header_members = []
        if op.fixed_header:
            fixed_header_members = self.consts[op.fixed_header].members
        # Fixed-header members are packed positionally right after the genl
        # header; read them (instead of popping) so the caller's dict is left
        # untouched, and skip them when emitting attributes below.
        fixed_header_names = set()
        for m in fixed_header_members:
            value = vals[m.name]
            fmt, _ = NlAttr.type_formats[m.type]
            msg += struct.pack(fmt, value)
            fixed_header_names.add(m.name)
        for name, value in vals.items():
            if name in fixed_header_names:
                continue
            msg += self._add_attr(op.attr_set.name, name, value)
        msg = _genl_msg_finalize(msg)

        self.sock.send(msg, 0)

        done = False
        rsp = []
        while not done:
            reply = self.sock.recv(128 * 1024)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            for nl_msg in nms:
                if nl_msg.extack:
                    self._decode_extack(msg, op.attr_set, nl_msg.extack,
                                        fixed_header_members)

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)
                    done = True
                    break

                gm = GenlMsg(nl_msg, fixed_header_members)
                # Check if this is a reply to our request
                if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
                    if gm.genl_cmd in self.async_msg_ids:
                        self.handle_ntf(nl_msg, gm)
                        continue
                    else:
                        print('Unexpected message: ' + repr(gm))
                        continue

                rsp.append(self._decode(gm.raw_attrs, op.attr_set.name)
                           | gm.fixed_header_attrs)

        if not rsp:
            return None
        if not dump and len(rsp) == 1:
            return rsp[0]
        return rsp

    def do(self, method, vals):
        """Run operation ``method`` as a regular (non-dump) request."""
        return self._op(method, vals)

    def dump(self, method, vals):
        """Run operation ``method`` as a dump request."""
        return self._op(method, vals, dump=True)