# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
from enum import Enum
import functools
import os
import random
import socket
import struct
from struct import Struct
import sys
import yaml
import ipaddress
import uuid
import queue
import selectors
import time

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Numeric netlink, generic netlink and nlctrl constants used by this module."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    # Flag bits carried in nla_type; NlAttr clears them to get the bare type.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Functional Enum: values start at 1 ('flag' == 1, 'u8' == 2, ...).
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                 's8', 's16', 's32', 's64',
                                 'binary', 'string', 'nul-string',
                                 'nested', 'nested-array',
                                 'bitfield32', 'sint', 'uint'])


class NlError(Exception):
    """Raised when a reply carries a non-zero error code.

    ``error`` holds the negated message error (a positive errno suitable for
    os.strerror()); ``nl_msg`` is the offending NlMsg.
    """

    def __init__(self, nl_msg):
        super().__init__(nl_msg)
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"


class ConfigError(Exception):
    """Raised for invalid local configuration (e.g. a too-small recv size)."""
    pass


class NlAttr:
    """One netlink attribute (TLV) parsed out of a raw buffer."""

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute whose 4-byte header starts at ``offset``.

        Raises:
            Exception: the header's length is shorter than the header itself.
            struct.error: fewer than 4 bytes are available at ``offset``.
        """
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        if self._len < 4:
            # nla_len includes the 4-byte header. A shorter value is malformed
            # and a zero one would make every "offset += full_len" loop spin.
            raise Exception(f"Invalid netlink attribute length {self._len} at offset {offset}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        # Despite the name this is the full nla_len (header included).
        self.payload_len = self._len
        # Total space taken in the buffer, rounded up to 4-byte alignment.
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the struct.Struct for scalar ``attr_type``.

        ``byte_order`` of "big-endian" selects big endian, any other truthy
        value little endian, and None/empty the native byte order.
        """
        formats = cls.type_formats[attr_type]
        if byte_order:
            return formats.big if byte_order == "big-endian" \
                else formats.little
        return formats.native

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the payload as a single ``attr_type`` scalar."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a 'sint'/'uint' payload whose width (4 or 8 bytes) is implied by its length."""
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        # 's'/'u' prefix + bit width, e.g. 'u32' or 's64'.
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII, dropping its last byte (the NUL terminator)."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as a native-endian array of ``type`` scalars."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """Sequence of NlAttr parsed back-to-back from ``msg`` starting at ``offset``."""

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        msg = ''
        for a in self.attrs:
            if msg:
                msg += '\n'
            msg += repr(a)
        return msg


class NlMsg:
    """One netlink message: 16-byte nlmsghdr plus payload.

    For NLMSG_ERROR and NLMSG_DONE the leading int of the payload is stored in
    ``error`` and ``done`` is set. When NLM_F_ACK_TLVS is present the extended
    ack attributes are decoded into the ``extack`` dict; otherwise ``extack``
    is None.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # error code + echoed original nlmsghdr. Only 16 bytes are echoed
            # because the sockets in this file set NETLINK_CAP_ACK.
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                self.annotate_extack(attr_space)

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY nest into a dict of policy fields."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                policy_type = attr.as_scalar('u32')
                policy['type'] = Netlink.AttrType(policy_type).name
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def annotate_extack(self, attr_space):
        """ Make extack more human friendly with attribute information """

        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
            miss_type = self.extack['miss-type']
            if miss_type in attr_space.attrs_by_val:
                spec = attr_space.attrs_by_val[miss_type]
                self.extack['miss-type'] = spec['name']
                if 'doc' in spec:
                    self.extack['miss-type-doc'] = spec['doc']

    def cmd(self):
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages contained in one received buffer."""

    def __init__(self, data):
        self.msgs = []

        offset = 0
        while offset < len(data):
            nl_len, = struct.unpack_from("I", data, offset)
            if nl_len < 16:
                # Shorter than an nlmsghdr: malformed, and a zero length
                # would otherwise loop forever.
                raise Exception(f"Invalid netlink message length {nl_len} at offset {offset}")
            msg = NlMsg(data, offset)
            # Messages start on 4-byte boundaries (NLMSG_ALIGN).
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Cache filled by _genl_load_families(): family name -> dict with 'id',
# 'name' and, when reported, 'maxattr' and 'mcast' (group name -> id).
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build an nlmsghdr (minus its length) followed by a genlmsghdr."""
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the total message length (including this 4-byte field)."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump nlctrl's family list into the module-level genl_family_name_to_id.

    On an error reply the error is printed and the (possibly partial) map is
    left in place.
    """
    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        global genl_family_name_to_id
        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in NlAttrs(gm.raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """Generic netlink view of an NlMsg: 4-byte genlmsghdr, then ``raw`` payload."""

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is only attached by NetlinkProtocol.decode().
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Header encoding/decoding for a classic (non-genetlink) netlink family."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Build an nlmsghdr without its leading length field."""
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        # version is unused: classic netlink has no version field.
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap ``nl_msg`` and attach its attributes as ``raw_attrs``.

        When ``op`` is None it is looked up from the message command via
        ``ynl.rsp_by_value``. Attributes start after the op's fixed header.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value

    def msghdr_size(self):
        return 16


class GenlProtocol(NetlinkProtocol):
    """Generic netlink protocol; the family id is resolved via nlctrl."""

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        # KeyError here means the kernel does not know the family.
        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = self.genl_family['id']

    def message(self, flags, command, version, seq=None):
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        if mcast_name not in self.genl_family['mcast']:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return self.genl_family['mcast'][mcast_name]

    def msghdr_size(self):
        # nlmsghdr + genlmsghdr
        return super().msghdr_size() + 4


class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes used to resolve selectors.

    Lookups search the innermost scope first, then each outer scope.
    """

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer=None):
        outer_scopes = outer.scopes if outer else []
        inner_scope = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner_scope] + outer_scopes

    def lookup(self, name):
        """Return the value of attribute ``name`` from the first scope defining it.

        Raises:
            Exception: the defining scope has no value for it, or no scope
                defines it.
        """
        for scope in self.scopes:
            if name in scope.spec:
                if name in scope.values:
                    return scope.values[name]
                spec_name = scope.spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")


#
# YNL implementation details.
463# 464 465 466class YnlFamily(SpecFamily): 467 def __init__(self, def_path, schema=None, process_unknown=False, 468 recv_size=0): 469 super().__init__(def_path, schema) 470 471 self.include_raw = False 472 self.process_unknown = process_unknown 473 474 try: 475 if self.proto == "netlink-raw": 476 self.nlproto = NetlinkProtocol(self.yaml['name'], 477 self.yaml['protonum']) 478 else: 479 self.nlproto = GenlProtocol(self.yaml['name']) 480 except KeyError: 481 raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel") 482 483 self._recv_dbg = False 484 # Note that netlink will use conservative (min) message size for 485 # the first dump recv() on the socket, our setting will only matter 486 # from the second recv() on. 487 self._recv_size = recv_size if recv_size else 131072 488 # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo) 489 # for a message, so smaller receive sizes will lead to truncation. 490 # Note that the min size for other families may be larger than 4k! 
491 if self._recv_size < 4000: 492 raise ConfigError() 493 494 self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num) 495 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) 496 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1) 497 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1) 498 499 self.async_msg_ids = set() 500 self.async_msg_queue = queue.Queue() 501 502 for msg in self.msgs.values(): 503 if msg.is_async: 504 self.async_msg_ids.add(msg.rsp_value) 505 506 for op_name, op in self.ops.items(): 507 bound_f = functools.partial(self._op, op_name) 508 setattr(self, op.ident_name, bound_f) 509 510 511 def ntf_subscribe(self, mcast_name): 512 mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups) 513 self.sock.bind((0, 0)) 514 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP, 515 mcast_id) 516 517 def set_recv_dbg(self, enabled): 518 self._recv_dbg = enabled 519 520 def _recv_dbg_print(self, reply, nl_msgs): 521 if not self._recv_dbg: 522 return 523 print("Recv: read", len(reply), "bytes,", 524 len(nl_msgs.msgs), "messages", file=sys.stderr) 525 for nl_msg in nl_msgs: 526 print(" ", nl_msg, file=sys.stderr) 527 528 def _encode_enum(self, attr_spec, value): 529 enum = self.consts[attr_spec['enum']] 530 if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): 531 scalar = 0 532 if isinstance(value, str): 533 value = [value] 534 for single_value in value: 535 scalar += enum.entries[single_value].user_value(as_flags = True) 536 return scalar 537 else: 538 return enum.entries[value].user_value() 539 540 def _get_scalar(self, attr_spec, value): 541 try: 542 return int(value) 543 except (ValueError, TypeError) as e: 544 if 'enum' in attr_spec: 545 return self._encode_enum(attr_spec, value) 546 if attr_spec.display_hint: 547 return self._from_string(value, attr_spec) 548 raise e 549 550 def _add_attr(self, space, name, value, 
search_attrs): 551 try: 552 attr = self.attr_sets[space][name] 553 except KeyError: 554 raise Exception(f"Space '{space}' has no attribute '{name}'") 555 nl_type = attr.value 556 557 if attr.is_multi and isinstance(value, list): 558 attr_payload = b'' 559 for subvalue in value: 560 attr_payload += self._add_attr(space, name, subvalue, search_attrs) 561 return attr_payload 562 563 if attr["type"] == 'nest': 564 nl_type |= Netlink.NLA_F_NESTED 565 attr_payload = b'' 566 sub_space = attr['nested-attributes'] 567 sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs) 568 for subname, subvalue in value.items(): 569 attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs) 570 elif attr["type"] == 'flag': 571 if not value: 572 # If value is absent or false then skip attribute creation. 573 return b'' 574 attr_payload = b'' 575 elif attr["type"] == 'string': 576 attr_payload = str(value).encode('ascii') + b'\x00' 577 elif attr["type"] == 'binary': 578 if value is None: 579 attr_payload = b'' 580 elif isinstance(value, bytes): 581 attr_payload = value 582 elif isinstance(value, str): 583 if attr.display_hint: 584 attr_payload = self._from_string(value, attr) 585 else: 586 attr_payload = bytes.fromhex(value) 587 elif isinstance(value, dict) and attr.struct_name: 588 attr_payload = self._encode_struct(attr.struct_name, value) 589 elif isinstance(value, list) and attr.sub_type in NlAttr.type_formats: 590 format = NlAttr.get_format(attr.sub_type) 591 attr_payload = b''.join([format.pack(x) for x in value]) 592 else: 593 raise Exception(f'Unknown type for binary attribute, value: {value}') 594 elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar: 595 scalar = self._get_scalar(attr, value) 596 if attr.is_auto_scalar: 597 attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64') 598 else: 599 attr_type = attr["type"] 600 format = NlAttr.get_format(attr_type, attr.byte_order) 601 attr_payload = format.pack(scalar) 602 elif 
attr['type'] in "bitfield32": 603 scalar_value = self._get_scalar(attr, value["value"]) 604 scalar_selector = self._get_scalar(attr, value["selector"]) 605 attr_payload = struct.pack("II", scalar_value, scalar_selector) 606 elif attr['type'] == 'sub-message': 607 msg_format, _ = self._resolve_selector(attr, search_attrs) 608 attr_payload = b'' 609 if msg_format.fixed_header: 610 attr_payload += self._encode_struct(msg_format.fixed_header, value) 611 if msg_format.attr_set: 612 if msg_format.attr_set in self.attr_sets: 613 nl_type |= Netlink.NLA_F_NESTED 614 sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs) 615 for subname, subvalue in value.items(): 616 attr_payload += self._add_attr(msg_format.attr_set, 617 subname, subvalue, sub_attrs) 618 else: 619 raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'") 620 else: 621 raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}') 622 623 pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4) 624 return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad 625 626 def _get_enum_or_unknown(self, enum, raw): 627 try: 628 name = enum.entries_by_val[raw].name 629 except KeyError as error: 630 if self.process_unknown: 631 name = f"Unknown({raw})" 632 else: 633 raise error 634 return name 635 636 def _decode_enum(self, raw, attr_spec): 637 enum = self.consts[attr_spec['enum']] 638 if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): 639 i = 0 640 value = set() 641 while raw: 642 if raw & 1: 643 value.add(self._get_enum_or_unknown(enum, i)) 644 raw >>= 1 645 i += 1 646 else: 647 value = self._get_enum_or_unknown(enum, raw) 648 return value 649 650 def _decode_binary(self, attr, attr_spec): 651 if attr_spec.struct_name: 652 decoded = self._decode_struct(attr.raw, attr_spec.struct_name) 653 elif attr_spec.sub_type: 654 decoded = attr.as_c_array(attr_spec.sub_type) 655 if 'enum' in attr_spec: 656 decoded = [ self._decode_enum(x, attr_spec) for x in decoded ] 657 
elif attr_spec.display_hint: 658 decoded = [ self._formatted_string(x, attr_spec.display_hint) 659 for x in decoded ] 660 else: 661 decoded = attr.as_bin() 662 if attr_spec.display_hint: 663 decoded = self._formatted_string(decoded, attr_spec.display_hint) 664 return decoded 665 666 def _decode_array_attr(self, attr, attr_spec): 667 decoded = [] 668 offset = 0 669 while offset < len(attr.raw): 670 item = NlAttr(attr.raw, offset) 671 offset += item.full_len 672 673 if attr_spec["sub-type"] == 'nest': 674 subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes']) 675 decoded.append({ item.type: subattrs }) 676 elif attr_spec["sub-type"] == 'binary': 677 subattr = item.as_bin() 678 if attr_spec.display_hint: 679 subattr = self._formatted_string(subattr, attr_spec.display_hint) 680 decoded.append(subattr) 681 elif attr_spec["sub-type"] in NlAttr.type_formats: 682 subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order) 683 if 'enum' in attr_spec: 684 subattr = self._decode_enum(subattr, attr_spec) 685 elif attr_spec.display_hint: 686 subattr = self._formatted_string(subattr, attr_spec.display_hint) 687 decoded.append(subattr) 688 else: 689 raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}') 690 return decoded 691 692 def _decode_nest_type_value(self, attr, attr_spec): 693 decoded = {} 694 value = attr 695 for name in attr_spec['type-value']: 696 value = NlAttr(value.raw, 0) 697 decoded[name] = value.type 698 subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes']) 699 decoded.update(subattrs) 700 return decoded 701 702 def _decode_unknown(self, attr): 703 if attr.is_nest: 704 return self._decode(NlAttrs(attr.raw), None) 705 else: 706 return attr.as_bin() 707 708 def _rsp_add(self, rsp, name, is_multi, decoded): 709 if is_multi == None: 710 if name in rsp and type(rsp[name]) is not list: 711 rsp[name] = [rsp[name]] 712 is_multi = True 713 else: 714 is_multi = False 715 716 if not is_multi: 717 
rsp[name] = decoded 718 elif name in rsp: 719 rsp[name].append(decoded) 720 else: 721 rsp[name] = [decoded] 722 723 def _resolve_selector(self, attr_spec, search_attrs): 724 sub_msg = attr_spec.sub_message 725 if sub_msg not in self.sub_msgs: 726 raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}") 727 sub_msg_spec = self.sub_msgs[sub_msg] 728 729 selector = attr_spec.selector 730 value = search_attrs.lookup(selector) 731 if value not in sub_msg_spec.formats: 732 raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'") 733 734 spec = sub_msg_spec.formats[value] 735 return spec, value 736 737 def _decode_sub_msg(self, attr, attr_spec, search_attrs): 738 msg_format, _ = self._resolve_selector(attr_spec, search_attrs) 739 decoded = {} 740 offset = 0 741 if msg_format.fixed_header: 742 decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header)); 743 offset = self._struct_size(msg_format.fixed_header) 744 if msg_format.attr_set: 745 if msg_format.attr_set in self.attr_sets: 746 subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set) 747 decoded.update(subdict) 748 else: 749 raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'") 750 return decoded 751 752 def _decode(self, attrs, space, outer_attrs = None): 753 rsp = dict() 754 if space: 755 attr_space = self.attr_sets[space] 756 search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs) 757 758 for attr in attrs: 759 try: 760 attr_spec = attr_space.attrs_by_val[attr.type] 761 except (KeyError, UnboundLocalError): 762 if not self.process_unknown: 763 raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'") 764 attr_name = f"UnknownAttr({attr.type})" 765 self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr)) 766 continue 767 768 try: 769 if attr_spec["type"] == 'nest': 770 subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs) 771 decoded = subdict 772 
elif attr_spec["type"] == 'string': 773 decoded = attr.as_strz() 774 elif attr_spec["type"] == 'binary': 775 decoded = self._decode_binary(attr, attr_spec) 776 elif attr_spec["type"] == 'flag': 777 decoded = True 778 elif attr_spec.is_auto_scalar: 779 decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order) 780 if 'enum' in attr_spec: 781 decoded = self._decode_enum(decoded, attr_spec) 782 elif attr_spec["type"] in NlAttr.type_formats: 783 decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order) 784 if 'enum' in attr_spec: 785 decoded = self._decode_enum(decoded, attr_spec) 786 elif attr_spec.display_hint: 787 decoded = self._formatted_string(decoded, attr_spec.display_hint) 788 elif attr_spec["type"] == 'indexed-array': 789 decoded = self._decode_array_attr(attr, attr_spec) 790 elif attr_spec["type"] == 'bitfield32': 791 value, selector = struct.unpack("II", attr.raw) 792 if 'enum' in attr_spec: 793 value = self._decode_enum(value, attr_spec) 794 selector = self._decode_enum(selector, attr_spec) 795 decoded = {"value": value, "selector": selector} 796 elif attr_spec["type"] == 'sub-message': 797 decoded = self._decode_sub_msg(attr, attr_spec, search_attrs) 798 elif attr_spec["type"] == 'nest-type-value': 799 decoded = self._decode_nest_type_value(attr, attr_spec) 800 else: 801 if not self.process_unknown: 802 raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}') 803 decoded = self._decode_unknown(attr) 804 805 self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded) 806 except: 807 print(f"Error decoding '{attr_spec.name}' from '{space}'") 808 raise 809 810 return rsp 811 812 def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs): 813 for attr in attrs: 814 try: 815 attr_spec = attr_set.attrs_by_val[attr.type] 816 except KeyError: 817 raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'") 818 if offset > target: 819 break 820 if offset == target: 821 
return '.' + attr_spec.name 822 823 if offset + attr.full_len <= target: 824 offset += attr.full_len 825 continue 826 827 pathname = attr_spec.name 828 if attr_spec['type'] == 'nest': 829 sub_attrs = self.attr_sets[attr_spec['nested-attributes']] 830 search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name'])) 831 elif attr_spec['type'] == 'sub-message': 832 msg_format, value = self._resolve_selector(attr_spec, search_attrs) 833 if msg_format is None: 834 raise Exception(f"Can't resolve sub-message of {attr_spec['name']} for extack") 835 sub_attrs = self.attr_sets[msg_format.attr_set] 836 pathname += f"({value})" 837 else: 838 raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack") 839 offset += 4 840 subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs, 841 offset, target, search_attrs) 842 if subpath is None: 843 return None 844 return '.' + pathname + subpath 845 846 return None 847 848 def _decode_extack(self, request, op, extack, vals): 849 if 'bad-attr-offs' not in extack: 850 return 851 852 msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op) 853 offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header) 854 search_attrs = SpaceAttrs(op.attr_set, vals) 855 path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset, 856 extack['bad-attr-offs'], search_attrs) 857 if path: 858 del extack['bad-attr-offs'] 859 extack['bad-attr'] = path 860 861 def _struct_size(self, name): 862 if name: 863 members = self.consts[name].members 864 size = 0 865 for m in members: 866 if m.type in ['pad', 'binary']: 867 if m.struct: 868 size += self._struct_size(m.struct) 869 else: 870 size += m.len 871 else: 872 format = NlAttr.get_format(m.type, m.byte_order) 873 size += format.size 874 return size 875 else: 876 return 0 877 878 def _decode_struct(self, data, name): 879 members = self.consts[name].members 880 attrs = dict() 881 offset = 0 882 for m in members: 883 value = None 884 if m.type == 
'pad': 885 offset += m.len 886 elif m.type == 'binary': 887 if m.struct: 888 len = self._struct_size(m.struct) 889 value = self._decode_struct(data[offset : offset + len], 890 m.struct) 891 offset += len 892 else: 893 value = data[offset : offset + m.len] 894 offset += m.len 895 else: 896 format = NlAttr.get_format(m.type, m.byte_order) 897 [ value ] = format.unpack_from(data, offset) 898 offset += format.size 899 if value is not None: 900 if m.enum: 901 value = self._decode_enum(value, m) 902 elif m.display_hint: 903 value = self._formatted_string(value, m.display_hint) 904 attrs[m.name] = value 905 return attrs 906 907 def _encode_struct(self, name, vals): 908 members = self.consts[name].members 909 attr_payload = b'' 910 for m in members: 911 value = vals.pop(m.name) if m.name in vals else None 912 if m.type == 'pad': 913 attr_payload += bytearray(m.len) 914 elif m.type == 'binary': 915 if m.struct: 916 if value is None: 917 value = dict() 918 attr_payload += self._encode_struct(m.struct, value) 919 else: 920 if value is None: 921 attr_payload += bytearray(m.len) 922 else: 923 attr_payload += bytes.fromhex(value) 924 else: 925 if value is None: 926 value = 0 927 format = NlAttr.get_format(m.type, m.byte_order) 928 attr_payload += format.pack(value) 929 return attr_payload 930 931 def _formatted_string(self, raw, display_hint): 932 if display_hint == 'mac': 933 formatted = ':'.join('%02x' % b for b in raw) 934 elif display_hint == 'hex': 935 if isinstance(raw, int): 936 formatted = hex(raw) 937 else: 938 formatted = bytes.hex(raw, ' ') 939 elif display_hint in [ 'ipv4', 'ipv6' ]: 940 formatted = format(ipaddress.ip_address(raw)) 941 elif display_hint == 'uuid': 942 formatted = str(uuid.UUID(bytes=raw)) 943 else: 944 formatted = raw 945 return formatted 946 947 def _from_string(self, string, attr_spec): 948 if attr_spec.display_hint in ['ipv4', 'ipv6']: 949 ip = ipaddress.ip_address(string) 950 if attr_spec['type'] == 'binary': 951 raw = ip.packed 952 else: 953 
raw = int(ip) 954 else: 955 raise Exception(f"Display hint '{attr_spec.display_hint}' not implemented" 956 f" when parsing '{attr_spec['name']}'") 957 return raw 958 959 def handle_ntf(self, decoded): 960 msg = dict() 961 if self.include_raw: 962 msg['raw'] = decoded 963 op = self.rsp_by_value[decoded.cmd()] 964 attrs = self._decode(decoded.raw_attrs, op.attr_set.name) 965 if op.fixed_header: 966 attrs.update(self._decode_struct(decoded.raw, op.fixed_header)) 967 968 msg['name'] = op['name'] 969 msg['msg'] = attrs 970 self.async_msg_queue.put(msg) 971 972 def check_ntf(self): 973 while True: 974 try: 975 reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT) 976 except BlockingIOError: 977 return 978 979 nms = NlMsgs(reply) 980 self._recv_dbg_print(reply, nms) 981 for nl_msg in nms: 982 if nl_msg.error: 983 print("Netlink error in ntf!?", os.strerror(-nl_msg.error)) 984 print(nl_msg) 985 continue 986 if nl_msg.done: 987 print("Netlink done while checking for ntf!?") 988 continue 989 990 decoded = self.nlproto.decode(self, nl_msg, None) 991 if decoded.cmd() not in self.async_msg_ids: 992 print("Unexpected msg id while checking for ntf", decoded) 993 continue 994 995 self.handle_ntf(decoded) 996 997 def poll_ntf(self, duration=None): 998 start_time = time.time() 999 selector = selectors.DefaultSelector() 1000 selector.register(self.sock, selectors.EVENT_READ) 1001 1002 while True: 1003 try: 1004 yield self.async_msg_queue.get_nowait() 1005 except queue.Empty: 1006 if duration is not None: 1007 timeout = start_time + duration - time.time() 1008 if timeout <= 0: 1009 return 1010 else: 1011 timeout = None 1012 events = selector.select(timeout) 1013 if events: 1014 self.check_ntf() 1015 1016 def operation_do_attributes(self, name): 1017 """ 1018 For a given operation name, find and return a supported 1019 set of attributes (as a dict). 
1020 """ 1021 op = self.find_operation(name) 1022 if not op: 1023 return None 1024 1025 return op['do']['request']['attributes'].copy() 1026 1027 def _encode_message(self, op, vals, flags, req_seq): 1028 nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK 1029 for flag in flags or []: 1030 nl_flags |= flag 1031 1032 msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq) 1033 if op.fixed_header: 1034 msg += self._encode_struct(op.fixed_header, vals) 1035 search_attrs = SpaceAttrs(op.attr_set, vals) 1036 for name, value in vals.items(): 1037 msg += self._add_attr(op.attr_set.name, name, value, search_attrs) 1038 msg = _genl_msg_finalize(msg) 1039 return msg 1040 1041 def _ops(self, ops): 1042 reqs_by_seq = {} 1043 req_seq = random.randint(1024, 65535) 1044 payload = b'' 1045 for (method, vals, flags) in ops: 1046 op = self.ops[method] 1047 msg = self._encode_message(op, vals, flags, req_seq) 1048 reqs_by_seq[req_seq] = (op, vals, msg, flags) 1049 payload += msg 1050 req_seq += 1 1051 1052 self.sock.send(payload, 0) 1053 1054 done = False 1055 rsp = [] 1056 op_rsp = [] 1057 while not done: 1058 reply = self.sock.recv(self._recv_size) 1059 nms = NlMsgs(reply) 1060 self._recv_dbg_print(reply, nms) 1061 for nl_msg in nms: 1062 if nl_msg.nl_seq in reqs_by_seq: 1063 (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq] 1064 if nl_msg.extack: 1065 nl_msg.annotate_extack(op.attr_set) 1066 self._decode_extack(req_msg, op, nl_msg.extack, vals) 1067 else: 1068 op = None 1069 req_flags = [] 1070 1071 if nl_msg.error: 1072 raise NlError(nl_msg) 1073 if nl_msg.done: 1074 if nl_msg.extack: 1075 print("Netlink warning:") 1076 print(nl_msg) 1077 1078 if Netlink.NLM_F_DUMP in req_flags: 1079 rsp.append(op_rsp) 1080 elif not op_rsp: 1081 rsp.append(None) 1082 elif len(op_rsp) == 1: 1083 rsp.append(op_rsp[0]) 1084 else: 1085 rsp.append(op_rsp) 1086 op_rsp = [] 1087 1088 del reqs_by_seq[nl_msg.nl_seq] 1089 done = len(reqs_by_seq) == 0 1090 break 1091 1092 decoded = 
self.nlproto.decode(self, nl_msg, op) 1093 1094 # Check if this is a reply to our request 1095 if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value: 1096 if decoded.cmd() in self.async_msg_ids: 1097 self.handle_ntf(decoded) 1098 continue 1099 else: 1100 print('Unexpected message: ' + repr(decoded)) 1101 continue 1102 1103 rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name) 1104 if op.fixed_header: 1105 rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header)) 1106 op_rsp.append(rsp_msg) 1107 1108 return rsp 1109 1110 def _op(self, method, vals, flags=None, dump=False): 1111 req_flags = flags or [] 1112 if dump: 1113 req_flags.append(Netlink.NLM_F_DUMP) 1114 1115 ops = [(method, vals, req_flags)] 1116 return self._ops(ops)[0] 1117 1118 def do(self, method, vals, flags=None): 1119 return self._op(method, vals, flags) 1120 1121 def dump(self, method, vals): 1122 return self._op(method, vals, dump=True) 1123 1124 def do_multi(self, ops): 1125 return self._ops(ops) 1126