# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
from enum import Enum
import functools
import os
import random
import socket
import struct
from struct import Struct
import sys
import ipaddress
import uuid
import queue
import selectors
import time

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Netlink / generic netlink protocol constants used by this module."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Policy attribute types; the functional Enum API numbers them from 1,
    # so 'flag' == 1, 'u8' == 2, ...
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                  's8', 's16', 's32', 's64',
                                  'binary', 'string', 'nul-string',
                                  'nested', 'nested-array',
                                  'bitfield32', 'sint', 'uint'])


class NlError(Exception):
    """Raised when the kernel answers a request with a netlink error.

    Attributes:
        nl_msg: the NlMsg carrying the error.
        error: the positive errno value (negated from nl_msg.error).
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"


class ConfigError(Exception):
    """Raised for invalid library configuration (e.g. too small recv size)."""
    pass


class NlAttr:
    """A single netlink attribute (TLV) parsed from a byte buffer.

    Attributes:
        type: attribute type with the NLA_F_NESTED / NLA_F_NET_BYTEORDER
            bits masked off.
        is_nest: non-zero when NLA_F_NESTED was set.
        payload_len: the nla_len header field (it includes the 4-byte header).
        full_len: payload_len rounded up to 4 bytes, i.e. the distance to the
            next attribute.
        raw: the attribute payload without the 4-byte header.
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # nla_len covers the header itself, so anything below 4 is malformed.
        # A zero length would also make every caller that advances by
        # full_len loop forever, so reject it here.
        if self._len < 4:
            raise Exception(f"Malformed netlink attribute: length {self._len} at offset {offset}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the precompiled Struct for a scalar type name.

        byte_order "big-endian" selects big-endian; any other truthy value
        selects little-endian; a falsy value selects native order.
        """
        formats = cls.type_formats[attr_type]
        if byte_order:
            return formats.big if byte_order == "big-endian" \
                else formats.little
        return formats.native

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the payload as a fixed-size scalar of type attr_type."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a variable-width ('sint'/'uint') scalar of 4 or 8 bytes."""
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar len payload be 4 or 8 bytes, got {len(self.raw)}")
        # 's' or 'u' prefix plus the actual bit width, e.g. 'u32' / 's64'.
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode an ASCII payload, dropping its last (NUL) byte."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as a packed array of native-order scalars."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """Sequence of NlAttr parsed back-to-back from msg starting at offset."""

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """One netlink message parsed from msg at offset.

    Sets nl_len/nl_type/nl_flags/nl_seq/nl_portid from the 16-byte header,
    raw to the payload, and for NLMSG_ERROR / NLMSG_DONE also error, done and
    (when NLM_F_ACK_TLVS is set) a decoded extack dict.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)
        # nlmsg_len includes the 16-byte header; a smaller value is malformed
        # and a zero value would stall NlMsgs forever.
        if self.nl_len < 16:
            raise Exception(f"Malformed netlink message: length {self.nl_len} at offset {offset}")

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # 4-byte error code + echoed 16-byte request header. Assumes a
            # capped ack (NETLINK_CAP_ACK), i.e. no echoed request payload.
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                self.annotate_extack(attr_space)

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY nest into a dict."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_val = attr.as_scalar('u32')
                policy['type'] = Netlink.AttrType(type_val).name
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def annotate_extack(self, attr_space):
        """ Make extack more human friendly with attribute information """

        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
            miss_type = self.extack['miss-type']
            if miss_type in attr_space.attrs_by_val:
                spec = attr_space.attrs_by_val[miss_type]
                self.extack['miss-type'] = spec['name']
                if 'doc' in spec:
                    self.extack['miss-type-doc'] = spec['doc']

    def cmd(self):
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages contained in one received buffer."""

    def __init__(self, data):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset)
            # Consecutive messages start on NLMSG_ALIGN (4-byte) boundaries.
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Lazily filled by _genl_load_families(): family name -> family info dict.
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build nlmsghdr (minus the length) + genlmsghdr for a request."""
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the 4-byte total length field to a message built without it."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump all generic netlink families into genl_family_name_to_id.

    Each entry maps a family name to a dict with 'id' and 'name' and, when
    reported, 'maxattr' and 'mcast' (group name -> group id). On a netlink
    error the error is printed and loading stops, leaving a partial table.
    """
    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        global genl_family_name_to_id
        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in NlAttrs(gm.raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """A generic netlink message: NlMsg plus the 4-byte genlmsghdr."""

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is only attached by NetlinkProtocol.decode(); tolerate
        # printing a message that was never decoded.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Message framing for classic (non-generic) netlink families."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Build an nlmsghdr without its leading length field."""
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        """Start a request; the command is the nlmsg type, version is unused."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap nl_msg and attach raw_attrs parsed after any fixed header.

        When op is None the op is looked up by the message's command via
        ynl.rsp_by_value.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value

    def msghdr_size(self):
        """Size of the protocol's message header (nlmsghdr only)."""
        return 16


class GenlProtocol(NetlinkProtocol):
    """Framing for generic netlink; family id and groups come from nlctrl.

    Raises KeyError if the family is not registered with the kernel.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = self.genl_family['id']

    def message(self, flags, command, version, seq=None):
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        if mcast_name not in self.genl_family['mcast']:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return self.genl_family['mcast'][mcast_name]

    def msghdr_size(self):
        """nlmsghdr plus the 4-byte genlmsghdr."""
        return super().msghdr_size() + 4


class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes, innermost first.

    Used to resolve sub-message selectors by name, searching outwards.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        outer_scopes = outer.scopes if outer else []
        inner_scope = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner_scope] + outer_scopes

    def lookup(self, name):
        """Return the value of attribute name from the first scope defining it.

        Raises if the defining scope has no value, or no scope defines name.
        """
        for scope in self.scopes:
            if name in scope.spec:
                if name in scope.values:
                    return scope.values[name]
                spec_name = scope.spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """Netlink family driven by a YNL spec, talking to the kernel over a socket.

    Each spec operation is exposed as a bound method named after the op's
    ident_name (a functools.partial of _op).
    """

    def __init__(self, def_path, schema=None, process_unknown=False,
                 recv_size=0):
        super().__init__(def_path, schema)

        self.include_raw = False
        self.process_unknown = process_unknown

        try:
            if self.proto == "netlink-raw":
                self.nlproto = NetlinkProtocol(self.yaml['name'],
                                               self.yaml['protonum'])
            else:
                self.nlproto = GenlProtocol(self.yaml['name'])
        except KeyError:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

        self._recv_dbg = False
        # Note that netlink will use conservative (min) message size for
        # the first dump recv() on the socket, our setting will only matter
        # from the second recv() on.
        self._recv_size = recv_size if recv_size else 131072
        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
        # for a message, so smaller receive sizes will lead to truncation.
        # Note that the min size for other families may be larger than 4k!
        if self._recv_size < 4000:
            raise ConfigError(f"recv_size {self._recv_size} is below the minimum of 4000 bytes")

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

        self.async_msg_ids = set()
        self.async_msg_queue = queue.Queue()

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

    def ntf_subscribe(self, mcast_name):
        """Bind the socket and join the named multicast group."""
        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_id)

    def set_recv_dbg(self, enabled):
        self._recv_dbg = enabled

    def _recv_dbg_print(self, reply, nl_msgs):
        if not self._recv_dbg:
            return
        print("Recv: read", len(reply), "bytes,",
              len(nl_msgs.msgs), "messages", file=sys.stderr)
        for nl_msg in nl_msgs:
            print("    ", nl_msg, file=sys.stderr)

    def _encode_enum(self, attr_spec, value):
        """Map an enum entry name (or, for flags, name(s)) to its integer."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            scalar = 0
            if isinstance(value, str):
                value = [value]
            for single_value in value:
                # OR rather than add, so a repeated flag name sets its bit
                # once instead of carrying into the next bit.
                scalar |= enum.entries[single_value].user_value(as_flags = True)
            return scalar
        else:
            return enum.entries[value].user_value()

    def _get_scalar(self, attr_spec, value):
        """Convert value to int, falling back to enum names / display hints."""
        try:
            return int(value)
        except (ValueError, TypeError) as e:
            if 'enum' in attr_spec:
                return self._encode_enum(attr_spec, value)
            if attr_spec.display_hint:
                return self._from_string(value, attr_spec)
            raise e

    def _add_attr(self, space, name, value, search_attrs):
        """Encode attribute name from attribute set space as netlink TLV bytes.

        search_attrs is the SpaceAttrs scope chain used to resolve
        sub-message selectors. The caller's value is not modified.
        """
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value

        if attr.is_multi and isinstance(value, list):
            attr_payload = b''
            for subvalue in value:
                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
            return attr_payload

        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            sub_space = attr['nested-attributes']
            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
        elif attr["type"] == 'flag':
            if not value:
                # If value is absent or false then skip attribute creation.
                return b''
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            if value is None:
                attr_payload = b''
            elif isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                if attr.display_hint:
                    attr_payload = self._from_string(value, attr)
                else:
                    attr_payload = bytes.fromhex(value)
            elif isinstance(value, dict) and attr.struct_name:
                # _encode_struct() pops members; encode from a copy.
                attr_payload = self._encode_struct(attr.struct_name, dict(value))
            elif isinstance(value, list) and attr.sub_type in NlAttr.type_formats:
                fmt = NlAttr.get_format(attr.sub_type)
                attr_payload = b''.join([fmt.pack(x) for x in value])
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
            scalar = self._get_scalar(attr, value)
            if attr.is_auto_scalar:
                if attr["type"][0] == 's':
                    # An s32 holds -2**31 .. 2**31 - 1; bit_length() alone
                    # would let e.g. 2**31 through and break the s32 pack.
                    fits_32 = -(1 << 31) <= scalar < (1 << 31)
                else:
                    fits_32 = scalar.bit_length() <= 32
                attr_type = attr["type"][0] + ('32' if fits_32 else '64')
            else:
                attr_type = attr["type"]
            fmt = NlAttr.get_format(attr_type, attr.byte_order)
            attr_payload = fmt.pack(scalar)
        elif attr['type'] == 'bitfield32':
            scalar_value = self._get_scalar(attr, value["value"])
            scalar_selector = self._get_scalar(attr, value["selector"])
            attr_payload = struct.pack("II", scalar_value, scalar_selector)
        elif attr['type'] == 'sub-message':
            msg_format, _ = self._resolve_selector(attr, search_attrs)
            attr_payload = b''
            if msg_format.fixed_header:
                # _encode_struct() pops the header members so that only
                # attributes remain; do that on a copy of the caller's dict.
                value = dict(value)
                attr_payload += self._encode_struct(msg_format.fixed_header, value)
            if msg_format.attr_set:
                if msg_format.attr_set in self.attr_sets:
                    nl_type |= Netlink.NLA_F_NESTED
                    # SpaceAttrs needs the attribute-set object, not its name.
                    sub_attrs = SpaceAttrs(self.attr_sets[msg_format.attr_set],
                                           value, search_attrs)
                    for subname, subvalue in value.items():
                        attr_payload += self._add_attr(msg_format.attr_set,
                                                       subname, subvalue, sub_attrs)
                else:
                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _get_enum_or_unknown(self, enum, raw):
        """Name of enum value raw, or "Unknown(raw)" when process_unknown."""
        try:
            name = enum.entries_by_val[raw].name
        except KeyError as error:
            if self.process_unknown:
                name = f"Unknown({raw})"
            else:
                raise error
        return name

    def _decode_enum(self, raw, attr_spec):
        """Decode raw into an enum name, or a set of names for flags."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(self._get_enum_or_unknown(enum, i))
                raw >>= 1
                i += 1
        else:
            value = self._get_enum_or_unknown(enum, raw)
        return value

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attribute as a struct, C array, or bytes."""
        if attr_spec.struct_name:
            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
            if 'enum' in attr_spec:
                decoded = [ self._decode_enum(x, attr_spec) for x in decoded ]
            elif attr_spec.display_hint:
                decoded = [ self._formatted_string(x, attr_spec.display_hint)
                            for x in decoded ]
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = self._formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode_array_attr(self, attr, attr_spec):
        """Decode an indexed-array attribute into a list."""
        decoded = []
        offset = 0
        while offset < len(attr.raw):
            item = NlAttr(attr.raw, offset)
            offset += item.full_len

            if attr_spec["sub-type"] == 'nest':
                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
                decoded.append({ item.type: subattrs })
            elif attr_spec["sub-type"] == 'binary':
                subattr = item.as_bin()
                if attr_spec.display_hint:
                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
                decoded.append(subattr)
            elif attr_spec["sub-type"] in NlAttr.type_formats:
                subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
                if 'enum' in attr_spec:
                    subattr = self._decode_enum(subattr, attr_spec)
                elif attr_spec.display_hint:
                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
                decoded.append(subattr)
            else:
                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
        return decoded

    def _decode_nest_type_value(self, attr, attr_spec):
        """Decode nests whose attribute types carry values ('type-value')."""
        decoded = {}
        value = attr
        for name in attr_spec['type-value']:
            value = NlAttr(value.raw, 0)
            decoded[name] = value.type
        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
        decoded.update(subattrs)
        return decoded

    def _decode_unknown(self, attr):
        """Decode an attribute with no spec: nests recursively, else bytes."""
        if attr.is_nest:
            return self._decode(NlAttrs(attr.raw), None)
        else:
            return attr.as_bin()

    def _rsp_add(self, rsp, name, is_multi, decoded):
        """Store decoded under name in rsp.

        is_multi True appends to a list; False overwrites; None (multiplicity
        unknown) stores a scalar first and promotes to a list on repeats.
        """
        if is_multi is None:
            if name in rsp:
                # Promote on the second occurrence; later ones append to
                # the existing list instead of overwriting it.
                if type(rsp[name]) is not list:
                    rsp[name] = [rsp[name]]
                is_multi = True
            else:
                is_multi = False

        if not is_multi:
            rsp[name] = decoded
        elif name in rsp:
            rsp[name].append(decoded)
        else:
            rsp[name] = [decoded]

    def _resolve_selector(self, attr_spec, search_attrs):
        """Return (message format, selector value) for a sub-message attr."""
        sub_msg = attr_spec.sub_message
        if sub_msg not in self.sub_msgs:
            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
        sub_msg_spec = self.sub_msgs[sub_msg]

        selector = attr_spec.selector
        value = search_attrs.lookup(selector)
        if value not in sub_msg_spec.formats:
            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")

        spec = sub_msg_spec.formats[value]
        return spec, value

    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
        """Decode a sub-message: optional fixed header, then its attributes."""
        msg_format, _ = self._resolve_selector(attr_spec, search_attrs)
        decoded = {}
        offset = 0
        if msg_format.fixed_header:
            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
            offset = self._struct_size(msg_format.fixed_header)
        if msg_format.attr_set:
            if msg_format.attr_set in self.attr_sets:
                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
                decoded.update(subdict)
            else:
                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
        return decoded

    def _decode(self, attrs, space, outer_attrs = None):
        """Decode NlAttrs attrs from attribute set space into a dict.

        With space None (or an attribute absent from the set) attributes are
        unknown: decoded generically when process_unknown, otherwise an
        exception is raised.
        """
        rsp = dict()
        attr_space = None
        search_attrs = None
        if space:
            attr_space = self.attr_sets[space]
            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)

        for attr in attrs:
            if attr_space is None or attr.type not in attr_space.attrs_by_val:
                if not self.process_unknown:
                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
                attr_name = f"UnknownAttr({attr.type})"
                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
                continue
            attr_spec = attr_space.attrs_by_val[attr.type]

            try:
                if attr_spec["type"] == 'nest':
                    subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
                    decoded = subdict
                elif attr_spec["type"] == 'string':
                    decoded = attr.as_strz()
                elif attr_spec["type"] == 'binary':
                    decoded = self._decode_binary(attr, attr_spec)
                elif attr_spec["type"] == 'flag':
                    decoded = True
                elif attr_spec.is_auto_scalar:
                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
                    if 'enum' in attr_spec:
                        decoded = self._decode_enum(decoded, attr_spec)
                elif attr_spec["type"] in NlAttr.type_formats:
                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
                    if 'enum' in attr_spec:
                        decoded = self._decode_enum(decoded, attr_spec)
                    elif attr_spec.display_hint:
                        decoded = self._formatted_string(decoded, attr_spec.display_hint)
                elif attr_spec["type"] == 'indexed-array':
                    decoded = self._decode_array_attr(attr, attr_spec)
                elif attr_spec["type"] == 'bitfield32':
                    value, selector = struct.unpack("II", attr.raw)
                    if 'enum' in attr_spec:
                        value = self._decode_enum(value, attr_spec)
                        selector = self._decode_enum(selector, attr_spec)
                    decoded = {"value": value, "selector": selector}
                elif attr_spec["type"] == 'sub-message':
                    decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
                elif attr_spec["type"] == 'nest-type-value':
                    decoded = self._decode_nest_type_value(attr, attr_spec)
                else:
                    if not self.process_unknown:
                        raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
                    decoded = self._decode_unknown(attr)

                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
            except Exception:
                # Add context about which attribute failed, then re-raise.
                print(f"Error decoding '{attr_spec.name}' from '{space}'")
                raise

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs):
        """Return a dotted attribute path for the byte offset target, or None."""
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue

            # target lies inside this attribute: descend into it.
            pathname = attr_spec.name
            if attr_spec['type'] == 'nest':
                sub_attrs = self.attr_sets[attr_spec['nested-attributes']]
                search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name']))
            elif attr_spec['type'] == 'sub-message':
                msg_format, value = self._resolve_selector(attr_spec, search_attrs)
                if msg_format is None:
                    raise Exception(f"Can't resolve sub-message of {attr_spec['name']} for extack")
                sub_attrs = self.attr_sets[msg_format.attr_set]
                pathname += f"({value})"
            else:
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs,
                                               offset, target, search_attrs)
            if subpath is None:
                return None
            return '.' + pathname + subpath

        return None

    def _decode_extack(self, request, op, extack, vals):
        """Replace extack 'bad-attr-offs' with a 'bad-attr' path when resolvable."""
        if 'bad-attr-offs' not in extack:
            return

        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
                                        extack['bad-attr-offs'], search_attrs)
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def _struct_size(self, name):
        """Byte size of the struct definition name (0 when name is falsy)."""
        if not name:
            return 0
        size = 0
        for m in self.consts[name].members:
            if m.type in ['pad', 'binary']:
                if m.struct:
                    size += self._struct_size(m.struct)
                else:
                    size += m.len
            else:
                size += NlAttr.get_format(m.type, m.byte_order).size
        return size

    def _decode_struct(self, data, name):
        """Decode bytes data as struct name into a dict of member values."""
        members = self.consts[name].members
        attrs = dict()
        offset = 0
        for m in members:
            value = None
            if m.type == 'pad':
                offset += m.len
            elif m.type == 'binary':
                if m.struct:
                    size = self._struct_size(m.struct)
                    value = self._decode_struct(data[offset : offset + size],
                                                m.struct)
                    offset += size
                else:
                    value = data[offset : offset + m.len]
                    offset += m.len
            else:
                fmt = NlAttr.get_format(m.type, m.byte_order)
                [ value ] = fmt.unpack_from(data, offset)
                offset += fmt.size
            if value is not None:
                if m.enum:
                    value = self._decode_enum(value, m)
                elif m.display_hint:
                    value = self._formatted_string(value, m.display_hint)
                attrs[m.name] = value
        return attrs

    def _encode_struct(self, name, vals):
        """Encode struct name from dict vals; missing members become zero.

        NOTE: members are popped out of vals so callers can encode whatever
        remains as attributes. Callers pass a copy when the dict belongs to
        the user.
        """
        members = self.consts[name].members
        attr_payload = b''
        for m in members:
            value = vals.pop(m.name, None)
            if m.type == 'pad':
                attr_payload += bytearray(m.len)
            elif m.type == 'binary':
                if m.struct:
                    if value is None:
                        value = dict()
                    # Copy so the caller's nested dict is not emptied.
                    attr_payload += self._encode_struct(m.struct, dict(value))
                else:
                    if value is None:
                        attr_payload += bytearray(m.len)
                    else:
                        attr_payload += bytes.fromhex(value)
            else:
                if value is None:
                    value = 0
                fmt = NlAttr.get_format(m.type, m.byte_order)
                attr_payload += fmt.pack(value)
        return attr_payload

    def _formatted_string(self, raw, display_hint):
        """Render raw according to display_hint; unknown hints pass through."""
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            if isinstance(raw, int):
                formatted = hex(raw)
            else:
                formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def _from_string(self, string, attr_spec):
        """Parse a display-hinted string (ipv4/ipv6 only) into bytes or int."""
        if attr_spec.display_hint in ['ipv4', 'ipv6']:
            ip = ipaddress.ip_address(string)
            if attr_spec['type'] == 'binary':
                raw = ip.packed
            else:
                raw = int(ip)
        else:
            raise Exception(f"Display hint '{attr_spec.display_hint}' not implemented"
                            f" when parsing '{attr_spec['name']}'")
        return raw

    def handle_ntf(self, decoded):
        """Decode a notification and queue {'name', 'msg'[, 'raw']} for poll_ntf."""
        msg = dict()
        if self.include_raw:
            msg['raw'] = decoded
        op = self.rsp_by_value[decoded.cmd()]
        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
        if op.fixed_header:
            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))

        msg['name'] = op['name']
        msg['msg'] = attrs
        self.async_msg_queue.put(msg)

    def check_ntf(self):
        """Drain pending notifications from the socket without blocking."""
        while True:
            try:
                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                decoded = self.nlproto.decode(self, nl_msg, None)
                if decoded.cmd() not in self.async_msg_ids:
                    print("Unexpected msg id while checking for ntf", decoded)
                    continue

                self.handle_ntf(decoded)

    def poll_ntf(self, duration=None):
        """Yield queued notifications, waiting up to duration seconds (None: forever)."""
        start_time = time.time()
        selector = selectors.DefaultSelector()
        selector.register(self.sock, selectors.EVENT_READ)

        try:
            while True:
                try:
                    yield self.async_msg_queue.get_nowait()
                except queue.Empty:
                    if duration is not None:
                        timeout = start_time + duration - time.time()
                        if timeout <= 0:
                            return
                    else:
                        timeout = None
                    events = selector.select(timeout)
                    if events:
                        self.check_ntf()
        finally:
            # Release the selector even when the caller abandons the generator.
            selector.close()

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _encode_message(self, op, vals, flags, req_seq):
        """Build the complete request bytes for op; vals is not modified."""
        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag

        # _encode_struct() pops fixed-header members so only attributes are
        # left to encode; work on a copy so the caller's dict survives reuse.
        vals = dict(vals)
        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            msg += self._encode_struct(op.fixed_header, vals)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
        msg = _genl_msg_finalize(msg)
        return msg

    def _ops(self, ops):
        """Send (method, vals, flags) requests in one batch; return a reply per request.

        Each reply is a list for dumps, otherwise None, a single dict or a
        list of dicts. Raises NlError on a netlink error reply.
        """
        reqs_by_seq = {}
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            reqs_by_seq[req_seq] = (op, vals, msg, flags)
            payload += msg
            req_seq += 1

        self.sock.send(payload, 0)

        done = False
        rsp = []
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    if nl_msg.extack:
                        nl_msg.annotate_extack(op.attr_set)
                        self._decode_extack(req_msg, op, nl_msg.extack, vals)
                else:
                    op = None
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    break

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp

    def _op(self, method, vals, flags=None, dump=False):
        """Run one operation; dump=True adds NLM_F_DUMP without touching flags."""
        # Copy so NLM_F_DUMP is never appended to the caller's own list.
        req_flags = list(flags) if flags else []
        if dump:
            req_flags.append(Netlink.NLM_F_DUMP)

        ops = [(method, vals, req_flags)]
        return self._ops(ops)[0]

    def do(self, method, vals, flags=None):
        return self._op(method, vals, flags)

    def dump(self, method, vals):
        return self._op(method, vals, dump=True)

    def do_multi(self, ops):
        return self._ops(ops)