# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
from enum import Enum
import functools
import os
import random
import socket
import struct
from struct import Struct
import sys
import yaml
import ipaddress
import uuid
import queue
import selectors
import time

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Netlink and generic-netlink protocol constants used by this module."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Enum's functional API numbers members from 1 ('flag' == 1, 'u8' == 2, ...);
    # the order must match the kernel's policy attribute-type numbering.
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                  's8', 's16', 's32', 's64',
                                  'binary', 'string', 'nul-string',
                                  'nested', 'nested-array',
                                  'bitfield32', 'sint', 'uint'])


class NlError(Exception):
    """Raised when the kernel answers a request with a non-zero error code."""

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        # NlMsg.error holds the (negative) value from the message; store it positive.
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"


class ConfigError(Exception):
    """Raised for invalid library configuration."""


class NlAttr:
    """One netlink attribute (TLV) located at ``offset`` inside ``raw``.

    Attributes:
        type:        attribute type with NLA_F_NESTED / NLA_F_NET_BYTEORDER masked off
        is_nest:     non-zero if NLA_F_NESTED was set
        payload_len: nla_len as sent, which includes the 4-byte header
        full_len:    nla_len rounded up to 4 bytes (stride to the next attribute)
        raw:         payload bytes without the header
    """

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8': ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8': ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the precompiled Struct for scalar ``attr_type``.

        ``byte_order == "big-endian"`` selects big-endian, any other truthy
        value little-endian, and a falsy value the host's native order.
        Raises KeyError for types not in ``type_formats``.
        """
        scalar_format = cls.type_formats[attr_type]
        if byte_order:
            return scalar_format.big if byte_order == "big-endian" \
                else scalar_format.little
        return scalar_format.native

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as a single fixed-width scalar."""
        return self.get_format(attr_type, byte_order).unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable-width scalar whose payload is 4 or 8 bytes.

        The first letter of ``attr_type`` ('s' or 'u') selects signedness.
        """
        width = len(self.raw)
        if width not in (4, 8):
            raise Exception(f"Auto-scalar payload must be 4 or 8 bytes, got {width}")
        real_type = attr_type[0] + str(width * 8)
        return self.get_format(real_type, byte_order).unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII, dropping the last byte (the NUL terminator)."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of native-order scalars of ``type``."""
        item_format = self.get_format(type)
        return [x[0] for x in item_format.iter_unpack(self.raw)]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """Attributes parsed back-to-back from ``msg`` starting at ``offset``."""

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            # nla_len == 0 would never advance the offset; fail instead of hanging.
            if attr.full_len == 0:
                raise Exception(f"Malformed netlink attribute with zero length at offset {offset}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """One netlink message (16-byte nlmsghdr plus payload) at ``offset`` in ``msg``.

    For NLMSG_ERROR and NLMSG_DONE, ``error`` holds the status word and
    ``done`` is set. When NLM_F_ACK_TLVS is flagged, extended-ack TLVs are
    decoded into the ``extack`` dict; otherwise ``extack`` is None.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # 4-byte error code + echoed 16-byte request header; assumes the
            # request payload is not echoed (sockets here set NETLINK_CAP_ACK).
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                self.annotate_extack(attr_space)

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY nest into a dict of policy fields."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_val = attr.as_scalar('u32')
                policy['type'] = Netlink.AttrType(type_val).name
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def annotate_extack(self, attr_space):
        """ Make extack more human friendly with attribute information """

        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
            miss_type = self.extack['miss-type']
            if miss_type in attr_space.attrs_by_val:
                spec = attr_space.attrs_by_val[miss_type]
                self.extack['miss-type'] = spec['name']
                if 'doc' in spec:
                    self.extack['miss-type-doc'] = spec['doc']

    def cmd(self):
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages contained in one received buffer."""

    def __init__(self, data):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset)
            # A length shorter than the header is invalid and, if zero, would
            # never advance the offset.
            if msg.nl_len < 16:
                raise Exception(f"Malformed netlink message with length {msg.nl_len} at offset {offset}")
            # Consecutive messages start on 4-byte boundaries (NLMSG_ALIGN).
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Cache filled by _genl_load_families(): family name -> family info dict.
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build an nlmsghdr (without its length word) followed by a genlmsghdr.

    A random sequence number in [1, 1024] is used when ``seq`` is None.
    """
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the 4-byte nlmsg_len word (total length including itself)."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Fill ``genl_family_name_to_id`` by dumping CTRL_CMD_GETFAMILY from nlctrl.

    Each entry maps a family name to a dict with 'id' and 'name', plus
    'maxattr' and 'mcast' (group name -> group id) when reported. On an error
    reply the error is printed and loading stops with what was collected.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in NlAttrs(gm.raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """A generic netlink message: the 4-byte genlmsghdr is split off ``nl_msg.raw``."""

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\n\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is attached only by NetlinkProtocol.decode(); tolerate its absence.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Message framing for a classic (non-generic) netlink protocol."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Build an nlmsghdr without its leading length word."""
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        """Build a request header; ``version`` is unused by classic netlink."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap ``nl_msg`` and attach ``raw_attrs`` parsed after the op's fixed header.

        When ``op`` is None it is looked up from the message's command via
        ``ynl.rsp_by_value``.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the multicast group id declared in the spec for ``mcast_name``."""
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value

    def msghdr_size(self):
        return 16


class GenlProtocol(NetlinkProtocol):
    """Message framing for a generic netlink family resolved through nlctrl."""

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        # KeyError here means the kernel does not know the family.
        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']

    def message(self, flags, command, version, seq=None):
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the kernel-assigned id of multicast group ``mcast_name``."""
        # Families without multicast groups have no 'mcast' entry at all.
        groups = self.genl_family.get('mcast', {})
        if mcast_name not in groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return groups[mcast_name]

    def msghdr_size(self):
        # nlmsghdr plus the 4-byte genlmsghdr
        return super().msghdr_size() + 4


class SpaceAttrs:
    """Chain of (attribute-set spec, values dict) scopes, innermost first.

    Used to resolve sub-message selectors, which may name attributes of an
    enclosing nest.
    """

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer=None):
        outer_scopes = outer.scopes if outer else []
        inner_scope = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner_scope] + outer_scopes

    def lookup(self, name):
        """Return the value of ``name`` from the innermost scope whose spec defines it.

        Raises if that scope has no value for it, or if no scope defines it.
        """
        for scope in self.scopes:
            if name in scope.spec:
                if name in scope.values:
                    return scope.values[name]
                spec_name = scope.spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """Netlink client for one family described by the YAML spec at ``def_path``.

    Every operation in the spec is also exposed as a method named
    ``op.ident_name`` which forwards to ``_op()``.
    """

    def __init__(self, def_path, schema=None, process_unknown=False,
                 recv_size=0):
        super().__init__(def_path, schema)

        self.include_raw = False
        self.process_unknown = process_unknown

        try:
            if self.proto == "netlink-raw":
                self.nlproto = NetlinkProtocol(self.yaml['name'],
                                               self.yaml['protonum'])
            else:
                self.nlproto = GenlProtocol(self.yaml['name'])
        except KeyError:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

        self._recv_dbg = False
        # Note that netlink will use conservative (min) message size for
        # the first dump recv() on the socket, our setting will only matter
        # from the second recv() on.
        self._recv_size = recv_size if recv_size else 131072
        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
        # for a message, so smaller receive sizes will lead to truncation.
        # Note that the min size for other families may be larger than 4k!
        if self._recv_size < 4000:
            raise ConfigError(f"recv_size must be at least 4000 bytes, got {self._recv_size}")

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

        self.async_msg_ids = set()
        self.async_msg_queue = queue.Queue()

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

    def ntf_subscribe(self, mcast_name):
        """Join multicast group ``mcast_name`` so its notifications arrive on self.sock."""
        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_id)

    def set_recv_dbg(self, enabled):
        """Enable or disable printing of every received buffer to stderr."""
        self._recv_dbg = enabled

    def _recv_dbg_print(self, reply, nl_msgs):
        if not self._recv_dbg:
            return
        print("Recv: read", len(reply), "bytes,",
              len(nl_msgs.msgs), "messages", file=sys.stderr)
        for nl_msg in nl_msgs:
            print("    ", nl_msg, file=sys.stderr)

    def _encode_enum(self, attr_spec, value):
        """Convert an enum entry name (or names, for flags) to its integer value."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            scalar = 0
            if isinstance(value, str):
                value = [value]
            for single_value in value:
                # OR rather than add, so a repeated flag name cannot corrupt the mask.
                scalar |= enum.entries[single_value].user_value(as_flags=True)
            return scalar
        else:
            return enum.entries[value].user_value()

    def _get_scalar(self, attr_spec, value):
        """Convert ``value`` to an int directly, via enum names, or via display hint."""
        try:
            return int(value)
        except (ValueError, TypeError):
            if 'enum' in attr_spec:
                return self._encode_enum(attr_spec, value)
            if attr_spec.display_hint:
                return self._from_string(value, attr_spec)
            raise

    def _add_attr(self, space, name, value, search_attrs):
        """Encode attribute ``name`` of attribute set ``space`` as padded TLV bytes.

        ``search_attrs`` resolves sub-message selectors against sibling and
        enclosing values. Multi-attributes given as a list are emitted once
        per element; a false/absent flag emits nothing.
        """
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value

        if attr.is_multi and isinstance(value, list):
            return b''.join(self._add_attr(space, name, subvalue, search_attrs)
                            for subvalue in value)

        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            sub_space = attr['nested-attributes']
            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
        elif attr["type"] == 'flag':
            if not value:
                # If value is absent or false then skip attribute creation.
                return b''
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            if isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                if attr.display_hint:
                    attr_payload = self._from_string(value, attr)
                else:
                    attr_payload = bytes.fromhex(value)
            elif isinstance(value, dict) and attr.struct_name:
                attr_payload = self._encode_struct(attr.struct_name, value)
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
            scalar = self._get_scalar(attr, value)
            if attr.is_auto_scalar:
                # Use the 32-bit encoding whenever the value fits. For signed
                # types bit_length() is not enough: 2**31 has bit_length() 32
                # but does not fit in an s32.
                if attr["type"][0] == 's':
                    fits_32 = -(1 << 31) <= scalar < (1 << 31)
                else:
                    fits_32 = scalar.bit_length() <= 32
                attr_type = attr["type"][0] + ('32' if fits_32 else '64')
            else:
                attr_type = attr["type"]
            scalar_format = NlAttr.get_format(attr_type, attr.byte_order)
            attr_payload = scalar_format.pack(scalar)
        elif attr['type'] == "bitfield32":
            scalar_value = self._get_scalar(attr, value["value"])
            scalar_selector = self._get_scalar(attr, value["selector"])
            attr_payload = struct.pack("II", scalar_value, scalar_selector)
        elif attr['type'] == 'sub-message':
            msg_format, _ = self._resolve_selector(attr, search_attrs)
            # Work on a copy: _encode_struct() consumes fixed-header members
            # so they are not re-encoded as attributes below.
            value = dict(value)
            attr_payload = b''
            if msg_format.fixed_header:
                attr_payload += self._encode_struct(msg_format.fixed_header, value)
            if msg_format.attr_set:
                if msg_format.attr_set in self.attr_sets:
                    nl_type |= Netlink.NLA_F_NESTED
                    # SpaceAttrs needs the attribute-set object, not its name.
                    sub_attrs = SpaceAttrs(self.attr_sets[msg_format.attr_set],
                                           value, search_attrs)
                    for subname, subvalue in value.items():
                        attr_payload += self._add_attr(msg_format.attr_set,
                                                       subname, subvalue, sub_attrs)
                else:
                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Map an integer to its enum entry name, or to a set of names for flags."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attribute as a struct, a C array, or (hinted) bytes."""
        if attr_spec.struct_name:
            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
            if 'enum' in attr_spec:
                decoded = [self._decode_enum(x, attr_spec) for x in decoded]
            elif attr_spec.display_hint:
                decoded = [self._formatted_string(x, attr_spec.display_hint)
                           for x in decoded]
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = self._formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode_array_attr(self, attr, attr_spec):
        """Decode an indexed-array attribute into a list of its entries."""
        decoded = []
        offset = 0
        while offset < len(attr.raw):
            item = NlAttr(attr.raw, offset)
            # nla_len == 0 would never advance the offset; fail instead of hanging.
            if item.full_len == 0:
                raise Exception(f"Zero-length entry in indexed-array '{attr_spec['name']}'")
            offset += item.full_len

            if attr_spec["sub-type"] == 'nest':
                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
                decoded.append({item.type: subattrs})
            elif attr_spec["sub-type"] == 'binary':
                subattr = item.as_bin()
                if attr_spec.display_hint:
                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
                decoded.append(subattr)
            elif attr_spec["sub-type"] in NlAttr.type_formats:
                subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
                if 'enum' in attr_spec:
                    subattr = self._decode_enum(subattr, attr_spec)
                elif attr_spec.display_hint:
                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
                decoded.append(subattr)
            else:
                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
        return decoded

    def _decode_nest_type_value(self, attr, attr_spec):
        """Decode nests whose successive attribute *types* carry the values named in 'type-value'."""
        decoded = {}
        value = attr
        for name in attr_spec['type-value']:
            value = NlAttr(value.raw, 0)
            decoded[name] = value.type
        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
        decoded.update(subattrs)
        return decoded

    def _decode_unknown(self, attr):
        """Decode an attribute with no spec: nests recursively, otherwise raw bytes."""
        if attr.is_nest:
            return self._decode(NlAttrs(attr.raw), None)
        return attr.as_bin()

    def _rsp_add(self, rsp, name, is_multi, decoded):
        """Store ``decoded`` under ``name`` in ``rsp``.

        ``is_multi`` True appends to a list, False overwrites, and None
        (attributes without a spec) switches to a list on the first repeat.
        """
        if is_multi is None:
            if name in rsp:
                # Convert once; later repeats must keep appending, not overwrite.
                if not isinstance(rsp[name], list):
                    rsp[name] = [rsp[name]]
                is_multi = True
            else:
                is_multi = False

        if not is_multi:
            rsp[name] = decoded
        elif name in rsp:
            rsp[name].append(decoded)
        else:
            rsp[name] = [decoded]

    def _resolve_selector(self, attr_spec, search_attrs):
        """Return (message format, selector value) for a sub-message attribute."""
        sub_msg = attr_spec.sub_message
        if sub_msg not in self.sub_msgs:
            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
        sub_msg_spec = self.sub_msgs[sub_msg]

        selector = attr_spec.selector
        value = search_attrs.lookup(selector)
        if value not in sub_msg_spec.formats:
            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")

        spec = sub_msg_spec.formats[value]
        return spec, value

    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
        """Decode a sub-message attribute: optional fixed header, then attributes."""
        msg_format, _ = self._resolve_selector(attr_spec, search_attrs)
        decoded = {}
        offset = 0
        if msg_format.fixed_header:
            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
            offset = self._struct_size(msg_format.fixed_header)
        if msg_format.attr_set:
            if msg_format.attr_set in self.attr_sets:
                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set,
                                       search_attrs)
                decoded.update(subdict)
            else:
                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
        return decoded

    def _decode(self, attrs, space, outer_attrs=None):
        """Decode ``attrs`` using attribute set ``space`` into a dict keyed by attribute name.

        With ``space`` None (or an attribute missing from the set), attributes
        are stored as ``UnknownAttr(<type>)`` if ``process_unknown`` is set,
        otherwise an exception is raised.
        """
        rsp = dict()
        attr_space = None
        search_attrs = None
        if space:
            attr_space = self.attr_sets[space]
            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)

        for attr in attrs:
            if attr_space is None or attr.type not in attr_space.attrs_by_val:
                if not self.process_unknown:
                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
                attr_name = f"UnknownAttr({attr.type})"
                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
                continue
            attr_spec = attr_space.attrs_by_val[attr.type]

            try:
                if attr_spec["type"] == 'nest':
                    subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
                    decoded = subdict
                elif attr_spec["type"] == 'string':
                    decoded = attr.as_strz()
                elif attr_spec["type"] == 'binary':
                    decoded = self._decode_binary(attr, attr_spec)
                elif attr_spec["type"] == 'flag':
                    decoded = True
                elif attr_spec.is_auto_scalar:
                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
                elif attr_spec["type"] in NlAttr.type_formats:
                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
                    if 'enum' in attr_spec:
                        decoded = self._decode_enum(decoded, attr_spec)
                    elif attr_spec.display_hint:
                        decoded = self._formatted_string(decoded, attr_spec.display_hint)
                elif attr_spec["type"] == 'indexed-array':
                    decoded = self._decode_array_attr(attr, attr_spec)
                elif attr_spec["type"] == 'bitfield32':
                    value, selector = struct.unpack("II", attr.raw)
                    if 'enum' in attr_spec:
                        value = self._decode_enum(value, attr_spec)
                        selector = self._decode_enum(selector, attr_spec)
                    decoded = {"value": value, "selector": selector}
                elif attr_spec["type"] == 'sub-message':
                    decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
                elif attr_spec["type"] == 'nest-type-value':
                    decoded = self._decode_nest_type_value(attr, attr_spec)
                else:
                    if not self.process_unknown:
                        raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
                    decoded = self._decode_unknown(attr)

                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
            except Exception:
                # Add context about which attribute failed, then propagate.
                print(f"Error decoding '{attr_spec.name}' from '{space}'")
                raise

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs):
        """Return a dotted attribute path to byte ``target`` of the request, or None.

        ``offset`` is the absolute position of the first attribute in ``attrs``.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue

            pathname = attr_spec.name
            # Offset of the first child attribute inside attr.raw.
            sub_offset = 0
            if attr_spec['type'] == 'nest':
                sub_attrs = self.attr_sets[attr_spec['nested-attributes']]
                search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name']),
                                          search_attrs)
            elif attr_spec['type'] == 'sub-message':
                msg_format, value = self._resolve_selector(attr_spec, search_attrs)
                if msg_format is None:
                    raise Exception(f"Can't resolve sub-message of {attr_spec['name']} for extack")
                sub_attrs = self.attr_sets[msg_format.attr_set]
                # Attributes follow the sub-message's fixed header, if any.
                sub_offset = self._struct_size(msg_format.fixed_header)
                pathname += f"({value})"
            else:
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Skip this attribute's 4-byte header (and any fixed header).
            offset += 4 + sub_offset
            subpath = self._decode_extack_path(NlAttrs(attr.raw, sub_offset), sub_attrs,
                                               offset, target, search_attrs)
            if subpath is None:
                return None
            return '.' + pathname + subpath

        return None

    def _decode_extack(self, request, op, extack, vals):
        """Replace extack 'bad-attr-offs' with a readable 'bad-attr' path when resolvable."""
        if 'bad-attr-offs' not in extack:
            return

        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
                                        extack['bad-attr-offs'], search_attrs)
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def _struct_size(self, name):
        """Return the encoded size in bytes of struct ``name`` (0 if ``name`` is falsy)."""
        if not name:
            return 0
        size = 0
        for m in self.consts[name].members:
            if m.type in ['pad', 'binary']:
                size += self._struct_size(m.struct) if m.struct else m.len
            else:
                size += NlAttr.get_format(m.type, m.byte_order).size
        return size

    def _decode_struct(self, data, name):
        """Decode struct ``name`` from the start of ``data`` into a dict (pad members omitted)."""
        members = self.consts[name].members
        attrs = dict()
        offset = 0
        for m in members:
            value = None
            if m.type == 'pad':
                offset += m.len
            elif m.type == 'binary':
                if m.struct:
                    size = self._struct_size(m.struct)
                    value = self._decode_struct(data[offset : offset + size],
                                                m.struct)
                    offset += size
                else:
                    value = data[offset : offset + m.len]
                    offset += m.len
            else:
                member_format = NlAttr.get_format(m.type, m.byte_order)
                value = member_format.unpack_from(data, offset)[0]
                offset += member_format.size
            if value is not None:
                if m.enum:
                    value = self._decode_enum(value, m)
                elif m.display_hint:
                    value = self._formatted_string(value, m.display_hint)
                attrs[m.name] = value
        return attrs

    def _encode_struct(self, name, vals):
        """Encode struct ``name`` from ``vals``; missing members are zero-filled.

        Members found in ``vals`` are popped from it, which lets callers encode
        whatever remains as attributes. Binary members accept bytes or a hex
        string.
        """
        members = self.consts[name].members
        attr_payload = b''
        for m in members:
            value = vals.pop(m.name, None)
            if m.type == 'pad':
                attr_payload += bytearray(m.len)
            elif m.type == 'binary':
                if m.struct:
                    # Copy so the nested encode does not consume the caller's dict.
                    nested = dict(value) if value is not None else dict()
                    attr_payload += self._encode_struct(m.struct, nested)
                elif value is None:
                    attr_payload += bytearray(m.len)
                elif isinstance(value, bytes):
                    attr_payload += value
                else:
                    attr_payload += bytes.fromhex(value)
            else:
                if value is None:
                    value = 0
                member_format = NlAttr.get_format(m.type, m.byte_order)
                attr_payload += member_format.pack(value)
        return attr_payload

    def _formatted_string(self, raw, display_hint):
        """Render ``raw`` according to ``display_hint``; unknown hints return ``raw``."""
        if display_hint == 'mac':
            formatted = ':'.join(f'{b:02x}' for b in raw)
        elif display_hint == 'hex':
            if isinstance(raw, int):
                formatted = hex(raw)
            else:
                formatted = bytes.hex(raw, ' ')
        elif display_hint in ['ipv4', 'ipv6']:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def _from_string(self, string, attr_spec):
        """Parse a display-hinted string (only ipv4/ipv6) into bytes or an int."""
        if attr_spec.display_hint in ['ipv4', 'ipv6']:
            ip = ipaddress.ip_address(string)
            if attr_spec['type'] == 'binary':
                raw = ip.packed
            else:
                raw = int(ip)
        else:
            raise Exception(f"Display hint '{attr_spec.display_hint}' not implemented"
                            f" when parsing '{attr_spec['name']}'")
        return raw

    def handle_ntf(self, decoded):
        """Decode a notification and put {'name', 'msg'[, 'raw']} on async_msg_queue."""
        msg = dict()
        if self.include_raw:
            msg['raw'] = decoded
        op = self.rsp_by_value[decoded.cmd()]
        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
        if op.fixed_header:
            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))

        msg['name'] = op['name']
        msg['msg'] = attrs
        self.async_msg_queue.put(msg)

    def check_ntf(self):
        """Drain pending notifications from the socket without blocking."""
        while True:
            try:
                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                decoded = self.nlproto.decode(self, nl_msg, None)
                if decoded.cmd() not in self.async_msg_ids:
                    print("Unexpected msg id while checking for ntf", decoded)
                    continue

                self.handle_ntf(decoded)

    def poll_ntf(self, duration=None):
        """Yield queued notifications, waiting on the socket when the queue is empty.

        With ``duration`` None this waits indefinitely; otherwise the generator
        returns once ``duration`` seconds have elapsed and the queue is empty.
        """
        # Monotonic clock: wall-clock adjustments must not stretch or cut the wait.
        deadline = None if duration is None else time.monotonic() + duration
        # The selector holds an OS handle; the with-block releases it when the
        # generator finishes or is closed.
        with selectors.DefaultSelector() as selector:
            selector.register(self.sock, selectors.EVENT_READ)

            while True:
                try:
                    msg = self.async_msg_queue.get_nowait()
                except queue.Empty:
                    timeout = None
                    if deadline is not None:
                        timeout = deadline - time.monotonic()
                        if timeout <= 0:
                            return
                    if selector.select(timeout):
                        self.check_ntf()
                    continue
                yield msg

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _encode_message(self, op, vals, flags, req_seq):
        """Build a complete request message for ``op`` with attribute values ``vals``."""
        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag

        # Work on a shallow copy: _encode_struct() pops fixed-header members
        # (so they are not re-encoded as attributes), and the caller's dict
        # must stay intact for reuse.
        vals = dict(vals or {})

        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            msg += self._encode_struct(op.fixed_header, vals)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
        msg = _genl_msg_finalize(msg)
        return msg

    def _ops(self, ops):
        """Send (method, vals, flags) requests in one write and collect their replies.

        Returns one entry per request, appended as each request completes:
        a list of dicts for dumps; otherwise None, a single dict, or a list
        of dicts depending on how many replies arrived. Raises NlError on an
        error ack. Notifications received meanwhile go to handle_ntf().
        """
        reqs_by_seq = {}
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            reqs_by_seq[req_seq] = (op, vals, msg, flags)
            payload += msg
            req_seq += 1

        self.sock.send(payload, 0)

        done = False
        rsp = []
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    if nl_msg.extack:
                        nl_msg.annotate_extack(op.attr_set)
                        self._decode_extack(req_msg, op, nl_msg.extack, vals)
                else:
                    op = None
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    break

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                    else:
                        print('Unexpected message: ' + repr(decoded))
                    continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp

    def _op(self, method, vals, flags=None, dump=False):
        """Run a single operation; ``dump=True`` adds NLM_F_DUMP."""
        # Copy so appending NLM_F_DUMP never modifies the caller's list.
        req_flags = list(flags) if flags else []
        if dump:
            req_flags.append(Netlink.NLM_F_DUMP)

        ops = [(method, vals, req_flags)]
        return self._ops(ops)[0]

    def do(self, method, vals, flags=None):
        return self._op(method, vals, flags)

    def dump(self, method, vals):
        return self._op(method, vals, dump=True)

    def do_multi(self, ops):
        return self._ops(ops)