# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
from enum import Enum
import functools
import os
import random
import socket
import struct
from struct import Struct
import sys
import yaml
import ipaddress
import uuid
import queue
import selectors
import time

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Numeric netlink / genetlink / nlctrl constants used by this module."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # Note: several flag values below are reused (e.g. 0x100, 0x200); their
    # meaning depends on which kind of message they are set on.
    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Enum values start at 1; used to name NL_POLICY_TYPE_ATTR_TYPE values.
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                 's8', 's16', 's32', 's64',
                                 'binary', 'string', 'nul-string',
                                 'nested', 'nested-array',
                                 'bitfield32', 'sint', 'uint'])


class NlError(Exception):
    """Raised for a netlink reply carrying a non-zero error code.

    ``error`` holds the positive errno (the message stores it negated).
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"


class ConfigError(Exception):
    pass


class NlAttr:
    """One netlink attribute (TLV) parsed from ``raw`` at ``offset``.

    ``payload_len`` is the attribute's on-wire length field, which includes
    the 4-byte header; ``raw`` is the value with the header stripped, and
    ``full_len`` is the length rounded up to the 4-byte alignment used to
    locate the next attribute.
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # Strip the NESTED / NET_BYTEORDER flag bits to get the bare type id.
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for ``attr_type``.

        With no ``byte_order`` the native layout is used; "big-endian"
        selects big-endian and any other truthy value little-endian.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable-width sint/uint whose payload is 4 or 8 bytes."""
        if len(self.raw) not in (4, 8):
            raise Exception(f"Auto-scalar len payload be 4 or 8 bytes, got {len(self.raw)}")
        # 's' or 'u' prefix plus the actual width in bits, e.g. 'u64'.
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        # Drops the last byte, assumed to be the NUL terminator.
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        return self.raw

    def as_c_array(self, type):
        fmt = self.get_format(type)
        return [x[0] for x in fmt.iter_unpack(self.raw)]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """Sequence of NlAttr parsed back to back from ``msg`` starting at ``offset``."""

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            if attr.full_len == 0:
                # A zero length would never advance the offset (infinite loop).
                raise Exception(f"Malformed netlink attribute with zero length at offset {offset}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """A netlink message (16-byte nlmsghdr + payload) parsed from ``msg`` at ``offset``.

    For NLMSG_ERROR / NLMSG_DONE the error code is stored in ``error`` and
    ``done`` is set; extended-ack TLVs, if flagged, are decoded into the
    ``extack`` dict (otherwise ``extack`` is None).  When ``attr_space`` is
    given, extack attribute ids are annotated with their names.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # 4-byte error + 16-byte echoed request header; assumes the echoed
            # request is capped (NETLINK_CAP_ACK is set on the sockets opened here).
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            # Only meaningful (and only safe) when an extack dict exists.
            if attr_space:
                self.annotate_extack(attr_space)

    def _decode_policy(self, raw):
        """Decode a nested NLMSGERR_ATTR_POLICY payload into a dict."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                policy_type = attr.as_scalar('u32')
                policy['type'] = Netlink.AttrType(policy_type).name
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def annotate_extack(self, attr_space):
        """ Make extack more human friendly with attribute information """
        if not self.extack:
            return

        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
            miss_type = self.extack['miss-type']
            if miss_type in attr_space.attrs_by_val:
                spec = attr_space.attrs_by_val[miss_type]
                self.extack['miss-type'] = spec['name']
                if 'doc' in spec:
                    self.extack['miss-type-doc'] = spec['doc']

    def cmd(self):
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages contained back to back in one recv() buffer."""

    def __init__(self, data):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset)
            if msg.nl_len < 16:
                # Shorter than a header: malformed, and 0 would loop forever.
                raise Exception(f"Malformed netlink message with length {msg.nl_len} at offset {offset}")
            offset += msg.nl_len
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Cache filled by _genl_load_families(): family name -> dict with
# 'id', 'name' and optionally 'maxattr' / 'mcast' (group name -> id).
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build nlmsghdr (without length) + genlmsghdr; a random seq is used if none given."""
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the 4-byte total length (which counts itself) to ``msg``."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump all generic netlink families via nlctrl into genl_family_name_to_id."""
    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        global genl_family_name_to_id
        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in NlAttrs(gm.raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """Generic netlink view of an NlMsg: 4-byte genlmsghdr then ``raw`` payload."""

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is only set by NetlinkProtocol.decode(); tolerate its absence.
        for a in getattr(self, 'raw_attrs', []):
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Message framing for classic (non-generic) netlink families."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Build an nlmsghdr without its length field (added on finalize)."""
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap ``nl_msg`` and attach ``raw_attrs`` parsed after the op's fixed header.

        When ``op`` is None it is looked up from the message's command id.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value

    def msghdr_size(self):
        return 16


class GenlProtocol(NetlinkProtocol):
    """Message framing for generic netlink; resolves the family id via nlctrl.

    Raises KeyError if the family is not known to the kernel.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = self.genl_family['id']

    def message(self, flags, command, version, seq=None):
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        if mcast_name not in self.genl_family['mcast']:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return self.genl_family['mcast'][mcast_name]

    def msghdr_size(self):
        # nlmsghdr plus the 4-byte genlmsghdr
        return super().msghdr_size() + 4


class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes, innermost first.

    Used to resolve sub-message selectors that may live in an enclosing nest.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer=None):
        outer_scopes = outer.scopes if outer else []
        inner_scope = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner_scope] + outer_scopes

    def lookup(self, name):
        """Return the value of ``name`` from the innermost scope whose spec defines it."""
        for scope in self.scopes:
            if name in scope.spec:
                if name in scope.values:
                    return scope.values[name]
                spec_name = scope.spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """Netlink family driven by a YAML spec.

    One bound method per spec operation is installed on the instance (named
    after ``op.ident_name``), each forwarding to :meth:`_op`.
    """

    def __init__(self, def_path, schema=None, process_unknown=False,
                 recv_size=0):
        super().__init__(def_path, schema)

        self.include_raw = False
        self.process_unknown = process_unknown

        try:
            if self.proto == "netlink-raw":
                self.nlproto = NetlinkProtocol(self.yaml['name'],
                                               self.yaml['protonum'])
            else:
                self.nlproto = GenlProtocol(self.yaml['name'])
        except KeyError:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

        self._recv_dbg = False
        # Note that netlink will use conservative (min) message size for
        # the first dump recv() on the socket, our setting will only matter
        # from the second recv() on.
        self._recv_size = recv_size if recv_size else 131072
        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
        # for a message, so smaller receive sizes will lead to truncation.
        # Note that the min size for other families may be larger than 4k!
        if self._recv_size < 4000:
            raise ConfigError(f"recv_size {self._recv_size} is too small, must be at least 4000 bytes")

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

        self.async_msg_ids = set()
        self.async_msg_queue = queue.Queue()

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

    def ntf_subscribe(self, mcast_name):
        """Join the multicast group ``mcast_name`` on this socket."""
        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_id)

    def set_recv_dbg(self, enabled):
        self._recv_dbg = enabled

    def _recv_dbg_print(self, reply, nl_msgs):
        if not self._recv_dbg:
            return
        print("Recv: read", len(reply), "bytes,",
              len(nl_msgs.msgs), "messages", file=sys.stderr)
        for nl_msg in nl_msgs:
            print("    ", nl_msg, file=sys.stderr)

    def _encode_enum(self, attr_spec, value):
        """Map enum entry name(s) to their integer value.

        For flag enums ``value`` may be one name or an iterable of names; the
        flag bits are OR-ed so a repeated name is not counted twice.
        """
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            scalar = 0
            if isinstance(value, str):
                value = [value]
            for single_value in value:
                scalar |= enum.entries[single_value].user_value(as_flags=True)
            return scalar
        return enum.entries[value].user_value()

    def _get_scalar(self, attr_spec, value):
        """Convert ``value`` to int, falling back to enum names or display-hint strings."""
        try:
            return int(value)
        except (ValueError, TypeError):
            if 'enum' in attr_spec:
                return self._encode_enum(attr_spec, value)
            if attr_spec.display_hint:
                return self._from_string(value, attr_spec)
            raise

    def _add_attr(self, space, name, value, search_attrs):
        """Encode attribute ``name`` of attribute-set ``space`` as a padded TLV.

        ``search_attrs`` is the SpaceAttrs chain used to resolve sub-message
        selectors.  Multi-attrs given as a list emit one TLV per element.
        """
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value

        if attr.is_multi and isinstance(value, list):
            attr_payload = b''
            for subvalue in value:
                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
            return attr_payload

        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            sub_space = attr['nested-attributes']
            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
        elif attr["type"] == 'flag':
            if not value:
                # If value is absent or false then skip attribute creation.
                return b''
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            if isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                if attr.display_hint:
                    attr_payload = self._from_string(value, attr)
                else:
                    attr_payload = bytes.fromhex(value)
            elif isinstance(value, dict) and attr.struct_name:
                attr_payload = self._encode_struct(attr.struct_name, value)
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
            scalar = self._get_scalar(attr, value)
            if attr.is_auto_scalar:
                # Use the 32-bit encoding only when the value really fits:
                # signed needs [-2**31, 2**31), unsigned at most 32 bits.
                if attr["type"][0] == 's':
                    fits_32 = -(1 << 31) <= scalar < (1 << 31)
                else:
                    fits_32 = scalar.bit_length() <= 32
                attr_type = attr["type"][0] + ('32' if fits_32 else '64')
            else:
                attr_type = attr["type"]
            fmt = NlAttr.get_format(attr_type, attr.byte_order)
            attr_payload = fmt.pack(scalar)
        elif attr['type'] == "bitfield32":
            scalar_value = self._get_scalar(attr, value["value"])
            scalar_selector = self._get_scalar(attr, value["selector"])
            attr_payload = struct.pack("II", scalar_value, scalar_selector)
        elif attr['type'] == 'sub-message':
            msg_format, _ = self._resolve_selector(attr, search_attrs)
            # Work on a copy: _encode_struct() pops the fixed-header members
            # and must not modify the caller's dict.
            value = dict(value)
            attr_payload = b''
            if msg_format.fixed_header:
                attr_payload += self._encode_struct(msg_format.fixed_header, value)
            if msg_format.attr_set:
                if msg_format.attr_set in self.attr_sets:
                    nl_type |= Netlink.NLA_F_NESTED
                    # SpaceAttrs needs the attribute-set spec, not its name.
                    sub_attrs = SpaceAttrs(self.attr_sets[msg_format.attr_set],
                                           value, search_attrs)
                    for subname, subvalue in value.items():
                        attr_payload += self._add_attr(msg_format.attr_set,
                                                       subname, subvalue, sub_attrs)
                else:
                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _get_enum_or_unknown(self, enum, raw):
        try:
            name = enum.entries_by_val[raw].name
        except KeyError as error:
            if self.process_unknown:
                name = f"Unknown({raw})"
            else:
                raise error
        return name

    def _decode_enum(self, raw, attr_spec):
        """Map an integer to its enum name, or to a set of names for flag enums."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(self._get_enum_or_unknown(enum, i))
                raw >>= 1
                i += 1
        else:
            value = self._get_enum_or_unknown(enum, raw)
        return value

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attribute as a struct, a C array or raw bytes."""
        if attr_spec.struct_name:
            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
            if 'enum' in attr_spec:
                decoded = [self._decode_enum(x, attr_spec) for x in decoded]
            elif attr_spec.display_hint:
                decoded = [self._formatted_string(x, attr_spec.display_hint)
                           for x in decoded]
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = self._formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode_array_attr(self, attr, attr_spec):
        """Decode an indexed-array attribute into a list of its elements."""
        decoded = []
        offset = 0
        while offset < len(attr.raw):
            item = NlAttr(attr.raw, offset)
            offset += item.full_len

            if attr_spec["sub-type"] == 'nest':
                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
                decoded.append({item.type: subattrs})
            elif attr_spec["sub-type"] == 'binary':
                subattr = item.as_bin()
                if attr_spec.display_hint:
                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
                decoded.append(subattr)
            elif attr_spec["sub-type"] in NlAttr.type_formats:
                subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
                if 'enum' in attr_spec:
                    subattr = self._decode_enum(subattr, attr_spec)
                elif attr_spec.display_hint:
                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
                decoded.append(subattr)
            else:
                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
        return decoded

    def _decode_nest_type_value(self, attr, attr_spec):
        """Decode nests whose successive attribute *types* carry values (type-value)."""
        decoded = {}
        value = attr
        for name in attr_spec['type-value']:
            value = NlAttr(value.raw, 0)
            decoded[name] = value.type
        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
        decoded.update(subattrs)
        return decoded

    def _decode_unknown(self, attr):
        if attr.is_nest:
            return self._decode(NlAttrs(attr.raw), None)
        else:
            return attr.as_bin()

    def _rsp_add(self, rsp, name, is_multi, decoded):
        """Store ``decoded`` under ``name``; is_multi=None auto-promotes repeats to a list."""
        if is_multi is None:
            if name in rsp and type(rsp[name]) is not list:
                rsp[name] = [rsp[name]]
                is_multi = True
            else:
                is_multi = False

        if not is_multi:
            rsp[name] = decoded
        elif name in rsp:
            rsp[name].append(decoded)
        else:
            rsp[name] = [decoded]

    def _resolve_selector(self, attr_spec, search_attrs):
        """Return (message format, selector value) for a sub-message attribute."""
        sub_msg = attr_spec.sub_message
        if sub_msg not in self.sub_msgs:
            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
        sub_msg_spec = self.sub_msgs[sub_msg]

        selector = attr_spec.selector
        value = search_attrs.lookup(selector)
        if value not in sub_msg_spec.formats:
            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")

        spec = sub_msg_spec.formats[value]
        return spec, value

    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
        """Decode a sub-message: optional fixed header followed by attributes."""
        msg_format, _ = self._resolve_selector(attr_spec, search_attrs)
        decoded = {}
        offset = 0
        if msg_format.fixed_header:
            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
            offset = self._struct_size(msg_format.fixed_header)
        if msg_format.attr_set:
            if msg_format.attr_set in self.attr_sets:
                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
                decoded.update(subdict)
            else:
                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
        return decoded

    def _decode(self, attrs, space, outer_attrs=None):
        """Decode ``attrs`` using attribute-set ``space`` into a dict.

        With ``space`` None (or an attribute missing from the set) the
        attribute is treated as unknown: an exception unless process_unknown
        is set, in which case it is stored as ``UnknownAttr(<type>)``.
        """
        rsp = dict()
        attr_space = None
        search_attrs = None
        if space:
            attr_space = self.attr_sets[space]
            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)

        for attr in attrs:
            attr_spec = None
            if attr_space is not None:
                try:
                    attr_spec = attr_space.attrs_by_val[attr.type]
                except KeyError:
                    pass
            if attr_spec is None:
                if not self.process_unknown:
                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
                attr_name = f"UnknownAttr({attr.type})"
                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
                continue

            try:
                if attr_spec["type"] == 'nest':
                    subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
                    decoded = subdict
                elif attr_spec["type"] == 'string':
                    decoded = attr.as_strz()
                elif attr_spec["type"] == 'binary':
                    decoded = self._decode_binary(attr, attr_spec)
                elif attr_spec["type"] == 'flag':
                    decoded = True
                elif attr_spec.is_auto_scalar:
                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
                    if 'enum' in attr_spec:
                        decoded = self._decode_enum(decoded, attr_spec)
                elif attr_spec["type"] in NlAttr.type_formats:
                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
                    if 'enum' in attr_spec:
                        decoded = self._decode_enum(decoded, attr_spec)
                    elif attr_spec.display_hint:
                        decoded = self._formatted_string(decoded, attr_spec.display_hint)
                elif attr_spec["type"] == 'indexed-array':
                    decoded = self._decode_array_attr(attr, attr_spec)
                elif attr_spec["type"] == 'bitfield32':
                    value, selector = struct.unpack("II", attr.raw)
                    if 'enum' in attr_spec:
                        value = self._decode_enum(value, attr_spec)
                        selector = self._decode_enum(selector, attr_spec)
                    decoded = {"value": value, "selector": selector}
                elif attr_spec["type"] == 'sub-message':
                    decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
                elif attr_spec["type"] == 'nest-type-value':
                    decoded = self._decode_nest_type_value(attr, attr_spec)
                else:
                    if not self.process_unknown:
                        raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
                    decoded = self._decode_unknown(attr)

                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
            except Exception:
                print(f"Error decoding '{attr_spec.name}' from '{space}'")
                raise

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs):
        """Return a '.a.b(sel).c' path to the attribute at byte ``offset`` ``target``.

        ``offset`` is the byte offset of the first attribute in ``attrs``
        within the original request.  Returns None if no attribute starts at
        ``target``.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue

            pathname = attr_spec.name
            # Bytes preceding the inner attributes (a sub-message fixed header).
            hdr_size = 0
            if attr_spec['type'] == 'nest':
                sub_attrs = self.attr_sets[attr_spec['nested-attributes']]
                search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name']))
            elif attr_spec['type'] == 'sub-message':
                msg_format, value = self._resolve_selector(attr_spec, search_attrs)
                if msg_format is None:
                    raise Exception(f"Can't resolve sub-message of {attr_spec['name']} for extack")
                sub_attrs = self.attr_sets[msg_format.attr_set]
                hdr_size = self._struct_size(msg_format.fixed_header)
                pathname += f"({value})"
            else:
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Skip the 4-byte attribute header plus any fixed header.
            offset += 4 + hdr_size
            subpath = self._decode_extack_path(NlAttrs(attr.raw, hdr_size), sub_attrs,
                                               offset, target, search_attrs)
            if subpath is None:
                return None
            return '.' + pathname + subpath

        return None

    def _decode_extack(self, request, op, extack, vals):
        """Replace extack 'bad-attr-offs' by an attribute path when it can be resolved."""
        if 'bad-attr-offs' not in extack:
            return

        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
                                        extack['bad-attr-offs'], search_attrs)
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def _struct_size(self, name):
        """Byte size of struct ``name``; 0 when ``name`` is falsy."""
        if not name:
            return 0
        members = self.consts[name].members
        size = 0
        for m in members:
            if m.type in ['pad', 'binary']:
                if m.struct:
                    size += self._struct_size(m.struct)
                else:
                    size += m.len
            else:
                fmt = NlAttr.get_format(m.type, m.byte_order)
                size += fmt.size
        return size

    def _decode_struct(self, data, name):
        """Decode struct ``name`` from the start of ``data`` into a dict (pads omitted)."""
        members = self.consts[name].members
        attrs = dict()
        offset = 0
        for m in members:
            value = None
            if m.type == 'pad':
                offset += m.len
            elif m.type == 'binary':
                if m.struct:
                    size = self._struct_size(m.struct)
                    value = self._decode_struct(data[offset : offset + size],
                                                m.struct)
                    offset += size
                else:
                    value = data[offset : offset + m.len]
                    offset += m.len
            else:
                fmt = NlAttr.get_format(m.type, m.byte_order)
                value = fmt.unpack_from(data, offset)[0]
                offset += fmt.size
            if value is not None:
                if m.enum:
                    value = self._decode_enum(value, m)
                elif m.display_hint:
                    value = self._formatted_string(value, m.display_hint)
                attrs[m.name] = value
        return attrs

    def _encode_struct(self, name, vals):
        """Encode struct ``name`` from ``vals``, missing members default to zero.

        Members are popped out of ``vals`` so callers can encode whatever
        remains as attributes; callers pass a copy they own.
        """
        members = self.consts[name].members
        attr_payload = b''
        for m in members:
            value = vals.pop(m.name, None)
            if m.type == 'pad':
                attr_payload += bytearray(m.len)
            elif m.type == 'binary':
                if m.struct:
                    # Copy so the caller's nested dict is not emptied by pops.
                    attr_payload += self._encode_struct(m.struct,
                                                        dict(value) if value else {})
                elif value is None:
                    attr_payload += bytearray(m.len)
                elif isinstance(value, (bytes, bytearray)):
                    attr_payload += value
                else:
                    attr_payload += bytes.fromhex(value)
            else:
                if value is None:
                    value = 0
                fmt = NlAttr.get_format(m.type, m.byte_order)
                attr_payload += fmt.pack(value)
        return attr_payload

    def _formatted_string(self, raw, display_hint):
        """Render ``raw`` per ``display_hint`` (mac/hex/ipv4/ipv6/uuid); else return it unchanged."""
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            if isinstance(raw, int):
                formatted = hex(raw)
            else:
                formatted = bytes.hex(raw, ' ')
        elif display_hint in ['ipv4', 'ipv6']:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def _from_string(self, string, attr_spec):
        """Parse a display-hint string (only ipv4/ipv6 supported) into bytes or int."""
        if attr_spec.display_hint in ['ipv4', 'ipv6']:
            ip = ipaddress.ip_address(string)
            if attr_spec['type'] == 'binary':
                raw = ip.packed
            else:
                raw = int(ip)
        else:
            raise Exception(f"Display hint '{attr_spec.display_hint}' not implemented"
                            f" when parsing '{attr_spec['name']}'")
        return raw

    def handle_ntf(self, decoded):
        """Decode a notification and put {'name', 'msg'[, 'raw']} on async_msg_queue."""
        msg = dict()
        if self.include_raw:
            msg['raw'] = decoded
        op = self.rsp_by_value[decoded.cmd()]
        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
        if op.fixed_header:
            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))

        msg['name'] = op['name']
        msg['msg'] = attrs
        self.async_msg_queue.put(msg)

    def check_ntf(self):
        """Drain pending notifications from the socket without blocking."""
        while True:
            try:
                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                decoded = self.nlproto.decode(self, nl_msg, None)
                if decoded.cmd() not in self.async_msg_ids:
                    print("Unexpected msg id while checking for ntf", decoded)
                    continue

                self.handle_ntf(decoded)

    def poll_ntf(self, duration=None):
        """Yield notifications, waiting up to ``duration`` seconds (forever if None)."""
        # Monotonic clock: the deadline must not move with wall-clock changes.
        start_time = time.monotonic()
        selector = selectors.DefaultSelector()
        selector.register(self.sock, selectors.EVENT_READ)

        try:
            while True:
                try:
                    yield self.async_msg_queue.get_nowait()
                except queue.Empty:
                    if duration is not None:
                        timeout = start_time + duration - time.monotonic()
                        if timeout <= 0:
                            return
                    else:
                        timeout = None
                    events = selector.select(timeout)
                    if events:
                        self.check_ntf()
        finally:
            # Release the selector even if the consumer stops iterating early.
            selector.close()

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _encode_message(self, op, vals, flags, req_seq):
        """Build one complete request (header, fixed header, attributes) for ``op``."""
        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag

        # Work on a copy: _encode_struct() pops the fixed-header members out
        # of it, and the caller's dict must stay intact.
        vals = dict(vals) if vals else {}

        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            msg += self._encode_struct(op.fixed_header, vals)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
        msg = _genl_msg_finalize(msg)
        return msg

    def _ops(self, ops):
        """Send (method, vals, flags) requests in one write and collect the replies.

        Returns one entry per request: a list for dumps, otherwise None, the
        single reply, or a list of replies.  Raises NlError on error acks.
        """
        reqs_by_seq = {}
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            # Normalize flags so the NLM_F_DUMP membership test below works
            # when a caller (e.g. do_multi) passed None.
            reqs_by_seq[req_seq] = (op, vals, msg, flags or [])
            payload += msg
            req_seq += 1

        self.sock.send(payload, 0)

        done = False
        rsp = []
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    if nl_msg.extack:
                        nl_msg.annotate_extack(op.attr_set)
                        self._decode_extack(req_msg, op, nl_msg.extack, vals)
                else:
                    op = None
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    break

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                    else:
                        print('Unexpected message: ' + repr(decoded))
                    continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp

    def _op(self, method, vals, flags=None, dump=False):
        """Run a single operation; ``dump`` adds NLM_F_DUMP to the request flags."""
        # Copy so NLM_F_DUMP is never appended to the caller's list.
        req_flags = list(flags) if flags else []
        if dump:
            req_flags.append(Netlink.NLM_F_DUMP)

        ops = [(method, vals, req_flags)]
        return self._ops(ops)[0]

    def do(self, method, vals, flags=None):
        return self._op(method, vals, flags)

    def dump(self, method, vals):
        return self._op(method, vals, dump=True)

    def do_multi(self, ops):
        return self._ops(ops)