# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
from enum import Enum
import functools
import os
import random
import socket
import struct
from struct import Struct
import sys
import yaml
import ipaddress
import uuid
import queue
import selectors
import time

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Netlink, generic netlink and nlctrl numeric constants used by this module."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Functional Enum API numbers members from 1 ('flag' == 1, ... 'uint' == 17).
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                  's8', 's16', 's32', 's64',
                                  'binary', 'string', 'nul-string',
                                  'nested', 'nested-array',
                                  'bitfield32', 'sint', 'uint'])


class NlError(Exception):
    """Raised when the kernel answers a request with a non-zero error code.

    ``nl_msg`` is the offending NlMsg; ``error`` is the positive errno value.
    """

    def __init__(self, nl_msg):
        super().__init__(nl_msg)
        self.nl_msg = nl_msg
        # The kernel reports errors as negative errno values.
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"


class ConfigError(Exception):
    """Raised for invalid local configuration of the netlink client."""
    pass


class NlAttr:
    """A single netlink attribute (TLV) parsed out of a raw byte buffer.

    Attributes:
        type: attribute type with the NESTED / NET_BYTEORDER flag bits removed.
        is_nest: non-zero when NLA_F_NESTED was set.
        payload_len: the length field as sent (header included).
        full_len: payload_len rounded up to the 4-byte attribute alignment.
        raw: payload bytes (header excluded).
    """

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for a scalar type name and optional byte order.

        byte_order == "big-endian" selects big endian, any other truthy value
        little endian, and a falsy value native order.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as a fixed-size scalar of ``attr_type``."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable-width (sint/uint) scalar from a 4 or 8 byte payload."""
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        # 'sint' -> 's32'/'s64', 'uint' -> 'u32'/'u64' depending on the payload size.
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII, dropping the final (NUL) byte."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of native-order ``type`` scalars."""
        fmt = self.get_format(type)
        return [x[0] for x in fmt.iter_unpack(self.raw)]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """Sequence of NlAttr parsed back-to-back from ``msg`` starting at ``offset``."""

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            if attr.payload_len < 4:
                # A length shorter than the 4-byte header is malformed; a zero
                # length would otherwise never advance the offset (endless loop).
                raise Exception(f"Malformed netlink attribute at offset {offset}: "
                                f"length {attr.payload_len}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """One netlink message (16-byte header plus payload) parsed from a buffer.

    For NLMSG_ERROR and NLMSG_DONE messages the leading int of the payload is
    stored in ``error`` (0 for a plain ACK) and ``done`` is set. When the
    header carries NLM_F_ACK_TLVS, the extended-ACK attributes are decoded
    into the ``extack`` dict (otherwise ``extack`` is None).
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # 4-byte error + 16-byte copy of the request header; assumes the
            # echoed request payload was capped (sockets here set NETLINK_CAP_ACK).
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    self.extack.setdefault('unknown', []).append(extack)

            if attr_space:
                self.annotate_extack(attr_space)

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY nest into a dict of policy fields."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_id = attr.as_scalar('u32')
                try:
                    policy['type'] = Netlink.AttrType(type_id).name
                except ValueError:
                    # Attribute type unknown to this table; report the raw id
                    # instead of failing the whole extack decode.
                    policy['type'] = type_id
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def annotate_extack(self, attr_space):
        """ Make extack more human friendly with attribute information """

        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
            miss_type = self.extack['miss-type']
            if miss_type in attr_space.attrs_by_val:
                spec = attr_space.attrs_by_val[miss_type]
                self.extack['miss-type'] = spec['name']
                if 'doc' in spec:
                    self.extack['miss-type-doc'] = spec['doc']

    def cmd(self):
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages contained in one received buffer."""

    def __init__(self, data):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset)
            if msg.nl_len < 16:
                # Shorter than the header: malformed, and a zero length would
                # never advance the offset.
                raise Exception(f"Malformed netlink message at offset {offset}: "
                                f"nl_len {msg.nl_len}")
            # Consecutive messages start on 4-byte boundaries (NLMSG_ALIGN).
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Cache filled by _genl_load_families(): family name -> dict with 'id',
# 'name' and optionally 'maxattr' / 'mcast' (mcast group name -> id).
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build a netlink + genetlink header without the leading length field."""
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the 4-byte total length (which counts itself) to ``msg``."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump all generic netlink families via nlctrl into genl_family_name_to_id.

    On a netlink error the error is printed and loading stops early.
    """
    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        global genl_family_name_to_id
        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in NlAttrs(gm.raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """A generic netlink message: NlMsg plus the 4-byte genetlink header.

    ``raw`` is the payload after the genetlink header. ``raw_attrs`` is only
    present after NetlinkProtocol.decode() has attached it.
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl) + '\n'
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is attached by NetlinkProtocol.decode(); tolerate its absence.
        for a in getattr(self, 'raw_attrs', []):
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Classic (non-generic) netlink protocol helper for a given protocol number."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Build a netlink header without the leading length field."""
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        # Classic netlink has no version field; the command is the nl_type.
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap ``nl_msg`` and attach ``raw_attrs`` parsed after the op's fixed header.

        When ``op`` is None it is looked up by the message's command value.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value

    def msghdr_size(self):
        return 16


class GenlProtocol(NetlinkProtocol):
    """Generic netlink protocol helper; resolves the family id via nlctrl.

    Raises KeyError when ``family_name`` is not registered with the kernel.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = self.genl_family['id']

    def message(self, flags, command, version, seq=None):
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        if mcast_name not in self.genl_family['mcast']:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return self.genl_family['mcast'][mcast_name]

    def msghdr_size(self):
        # netlink header + 4-byte genetlink header
        return super().msghdr_size() + 4


class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes searched innermost first."""

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer=None):
        outer_scopes = outer.scopes if outer else []
        inner_scope = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner_scope] + outer_scopes

    def lookup(self, name):
        """Return the value of attribute ``name`` from the innermost scope defining it.

        Raises if the defining scope has no value for it, or no scope defines it.
        """
        for scope in self.scopes:
            if name in scope.spec:
                if name in scope.values:
                    return scope.values[name]
                spec_name = scope.spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")


#
# YNL implementation details.
463# 464 465 466class YnlFamily(SpecFamily): 467 def __init__(self, def_path, schema=None, process_unknown=False, 468 recv_size=0): 469 super().__init__(def_path, schema) 470 471 self.include_raw = False 472 self.process_unknown = process_unknown 473 474 try: 475 if self.proto == "netlink-raw": 476 self.nlproto = NetlinkProtocol(self.yaml['name'], 477 self.yaml['protonum']) 478 else: 479 self.nlproto = GenlProtocol(self.yaml['name']) 480 except KeyError: 481 raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel") 482 483 self._recv_dbg = False 484 # Note that netlink will use conservative (min) message size for 485 # the first dump recv() on the socket, our setting will only matter 486 # from the second recv() on. 487 self._recv_size = recv_size if recv_size else 131072 488 # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo) 489 # for a message, so smaller receive sizes will lead to truncation. 490 # Note that the min size for other families may be larger than 4k! 
491 if self._recv_size < 4000: 492 raise ConfigError() 493 494 self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num) 495 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) 496 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1) 497 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1) 498 499 self.async_msg_ids = set() 500 self.async_msg_queue = queue.Queue() 501 502 for msg in self.msgs.values(): 503 if msg.is_async: 504 self.async_msg_ids.add(msg.rsp_value) 505 506 for op_name, op in self.ops.items(): 507 bound_f = functools.partial(self._op, op_name) 508 setattr(self, op.ident_name, bound_f) 509 510 511 def ntf_subscribe(self, mcast_name): 512 mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups) 513 self.sock.bind((0, 0)) 514 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP, 515 mcast_id) 516 517 def set_recv_dbg(self, enabled): 518 self._recv_dbg = enabled 519 520 def _recv_dbg_print(self, reply, nl_msgs): 521 if not self._recv_dbg: 522 return 523 print("Recv: read", len(reply), "bytes,", 524 len(nl_msgs.msgs), "messages", file=sys.stderr) 525 for nl_msg in nl_msgs: 526 print(" ", nl_msg, file=sys.stderr) 527 528 def _encode_enum(self, attr_spec, value): 529 enum = self.consts[attr_spec['enum']] 530 if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): 531 scalar = 0 532 if isinstance(value, str): 533 value = [value] 534 for single_value in value: 535 scalar += enum.entries[single_value].user_value(as_flags = True) 536 return scalar 537 else: 538 return enum.entries[value].user_value() 539 540 def _get_scalar(self, attr_spec, value): 541 try: 542 return int(value) 543 except (ValueError, TypeError) as e: 544 if 'enum' in attr_spec: 545 return self._encode_enum(attr_spec, value) 546 if attr_spec.display_hint: 547 return self._from_string(value, attr_spec) 548 raise e 549 550 def _add_attr(self, space, name, value, 
search_attrs): 551 try: 552 attr = self.attr_sets[space][name] 553 except KeyError: 554 raise Exception(f"Space '{space}' has no attribute '{name}'") 555 nl_type = attr.value 556 557 if attr.is_multi and isinstance(value, list): 558 attr_payload = b'' 559 for subvalue in value: 560 attr_payload += self._add_attr(space, name, subvalue, search_attrs) 561 return attr_payload 562 563 if attr["type"] == 'nest': 564 nl_type |= Netlink.NLA_F_NESTED 565 attr_payload = b'' 566 sub_space = attr['nested-attributes'] 567 sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs) 568 for subname, subvalue in value.items(): 569 attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs) 570 elif attr["type"] == 'flag': 571 if not value: 572 # If value is absent or false then skip attribute creation. 573 return b'' 574 attr_payload = b'' 575 elif attr["type"] == 'string': 576 attr_payload = str(value).encode('ascii') + b'\x00' 577 elif attr["type"] == 'binary': 578 if isinstance(value, bytes): 579 attr_payload = value 580 elif isinstance(value, str): 581 if attr.display_hint: 582 attr_payload = self._from_string(value, attr) 583 else: 584 attr_payload = bytes.fromhex(value) 585 elif isinstance(value, dict) and attr.struct_name: 586 attr_payload = self._encode_struct(attr.struct_name, value) 587 else: 588 raise Exception(f'Unknown type for binary attribute, value: {value}') 589 elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar: 590 scalar = self._get_scalar(attr, value) 591 if attr.is_auto_scalar: 592 attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64') 593 else: 594 attr_type = attr["type"] 595 format = NlAttr.get_format(attr_type, attr.byte_order) 596 attr_payload = format.pack(scalar) 597 elif attr['type'] in "bitfield32": 598 scalar_value = self._get_scalar(attr, value["value"]) 599 scalar_selector = self._get_scalar(attr, value["selector"]) 600 attr_payload = struct.pack("II", scalar_value, scalar_selector) 601 elif 
attr['type'] == 'sub-message': 602 msg_format, _ = self._resolve_selector(attr, search_attrs) 603 attr_payload = b'' 604 if msg_format.fixed_header: 605 attr_payload += self._encode_struct(msg_format.fixed_header, value) 606 if msg_format.attr_set: 607 if msg_format.attr_set in self.attr_sets: 608 nl_type |= Netlink.NLA_F_NESTED 609 sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs) 610 for subname, subvalue in value.items(): 611 attr_payload += self._add_attr(msg_format.attr_set, 612 subname, subvalue, sub_attrs) 613 else: 614 raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'") 615 else: 616 raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}') 617 618 pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4) 619 return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad 620 621 def _decode_enum(self, raw, attr_spec): 622 enum = self.consts[attr_spec['enum']] 623 if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): 624 i = 0 625 value = set() 626 while raw: 627 if raw & 1: 628 value.add(enum.entries_by_val[i].name) 629 raw >>= 1 630 i += 1 631 else: 632 value = enum.entries_by_val[raw].name 633 return value 634 635 def _decode_binary(self, attr, attr_spec): 636 if attr_spec.struct_name: 637 decoded = self._decode_struct(attr.raw, attr_spec.struct_name) 638 elif attr_spec.sub_type: 639 decoded = attr.as_c_array(attr_spec.sub_type) 640 if 'enum' in attr_spec: 641 decoded = [ self._decode_enum(x, attr_spec) for x in decoded ] 642 elif attr_spec.display_hint: 643 decoded = [ self._formatted_string(x, attr_spec.display_hint) 644 for x in decoded ] 645 else: 646 decoded = attr.as_bin() 647 if attr_spec.display_hint: 648 decoded = self._formatted_string(decoded, attr_spec.display_hint) 649 return decoded 650 651 def _decode_array_attr(self, attr, attr_spec): 652 decoded = [] 653 offset = 0 654 while offset < len(attr.raw): 655 item = NlAttr(attr.raw, offset) 656 offset += item.full_len 657 658 if 
attr_spec["sub-type"] == 'nest': 659 subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes']) 660 decoded.append({ item.type: subattrs }) 661 elif attr_spec["sub-type"] == 'binary': 662 subattr = item.as_bin() 663 if attr_spec.display_hint: 664 subattr = self._formatted_string(subattr, attr_spec.display_hint) 665 decoded.append(subattr) 666 elif attr_spec["sub-type"] in NlAttr.type_formats: 667 subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order) 668 if 'enum' in attr_spec: 669 subattr = self._decode_enum(subattr, attr_spec) 670 elif attr_spec.display_hint: 671 subattr = self._formatted_string(subattr, attr_spec.display_hint) 672 decoded.append(subattr) 673 else: 674 raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}') 675 return decoded 676 677 def _decode_nest_type_value(self, attr, attr_spec): 678 decoded = {} 679 value = attr 680 for name in attr_spec['type-value']: 681 value = NlAttr(value.raw, 0) 682 decoded[name] = value.type 683 subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes']) 684 decoded.update(subattrs) 685 return decoded 686 687 def _decode_unknown(self, attr): 688 if attr.is_nest: 689 return self._decode(NlAttrs(attr.raw), None) 690 else: 691 return attr.as_bin() 692 693 def _rsp_add(self, rsp, name, is_multi, decoded): 694 if is_multi == None: 695 if name in rsp and type(rsp[name]) is not list: 696 rsp[name] = [rsp[name]] 697 is_multi = True 698 else: 699 is_multi = False 700 701 if not is_multi: 702 rsp[name] = decoded 703 elif name in rsp: 704 rsp[name].append(decoded) 705 else: 706 rsp[name] = [decoded] 707 708 def _resolve_selector(self, attr_spec, search_attrs): 709 sub_msg = attr_spec.sub_message 710 if sub_msg not in self.sub_msgs: 711 raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}") 712 sub_msg_spec = self.sub_msgs[sub_msg] 713 714 selector = attr_spec.selector 715 value = search_attrs.lookup(selector) 716 if value not in 
sub_msg_spec.formats: 717 raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'") 718 719 spec = sub_msg_spec.formats[value] 720 return spec, value 721 722 def _decode_sub_msg(self, attr, attr_spec, search_attrs): 723 msg_format, _ = self._resolve_selector(attr_spec, search_attrs) 724 decoded = {} 725 offset = 0 726 if msg_format.fixed_header: 727 decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header)); 728 offset = self._struct_size(msg_format.fixed_header) 729 if msg_format.attr_set: 730 if msg_format.attr_set in self.attr_sets: 731 subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set) 732 decoded.update(subdict) 733 else: 734 raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'") 735 return decoded 736 737 def _decode(self, attrs, space, outer_attrs = None): 738 rsp = dict() 739 if space: 740 attr_space = self.attr_sets[space] 741 search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs) 742 743 for attr in attrs: 744 try: 745 attr_spec = attr_space.attrs_by_val[attr.type] 746 except (KeyError, UnboundLocalError): 747 if not self.process_unknown: 748 raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'") 749 attr_name = f"UnknownAttr({attr.type})" 750 self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr)) 751 continue 752 753 try: 754 if attr_spec["type"] == 'nest': 755 subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs) 756 decoded = subdict 757 elif attr_spec["type"] == 'string': 758 decoded = attr.as_strz() 759 elif attr_spec["type"] == 'binary': 760 decoded = self._decode_binary(attr, attr_spec) 761 elif attr_spec["type"] == 'flag': 762 decoded = True 763 elif attr_spec.is_auto_scalar: 764 decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order) 765 if 'enum' in attr_spec: 766 decoded = self._decode_enum(decoded, attr_spec) 767 elif attr_spec["type"] in NlAttr.type_formats: 768 decoded 
= attr.as_scalar(attr_spec['type'], attr_spec.byte_order) 769 if 'enum' in attr_spec: 770 decoded = self._decode_enum(decoded, attr_spec) 771 elif attr_spec.display_hint: 772 decoded = self._formatted_string(decoded, attr_spec.display_hint) 773 elif attr_spec["type"] == 'indexed-array': 774 decoded = self._decode_array_attr(attr, attr_spec) 775 elif attr_spec["type"] == 'bitfield32': 776 value, selector = struct.unpack("II", attr.raw) 777 if 'enum' in attr_spec: 778 value = self._decode_enum(value, attr_spec) 779 selector = self._decode_enum(selector, attr_spec) 780 decoded = {"value": value, "selector": selector} 781 elif attr_spec["type"] == 'sub-message': 782 decoded = self._decode_sub_msg(attr, attr_spec, search_attrs) 783 elif attr_spec["type"] == 'nest-type-value': 784 decoded = self._decode_nest_type_value(attr, attr_spec) 785 else: 786 if not self.process_unknown: 787 raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}') 788 decoded = self._decode_unknown(attr) 789 790 self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded) 791 except: 792 print(f"Error decoding '{attr_spec.name}' from '{space}'") 793 raise 794 795 return rsp 796 797 def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs): 798 for attr in attrs: 799 try: 800 attr_spec = attr_set.attrs_by_val[attr.type] 801 except KeyError: 802 raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'") 803 if offset > target: 804 break 805 if offset == target: 806 return '.' 
+ attr_spec.name 807 808 if offset + attr.full_len <= target: 809 offset += attr.full_len 810 continue 811 812 pathname = attr_spec.name 813 if attr_spec['type'] == 'nest': 814 sub_attrs = self.attr_sets[attr_spec['nested-attributes']] 815 search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name'])) 816 elif attr_spec['type'] == 'sub-message': 817 msg_format, value = self._resolve_selector(attr_spec, search_attrs) 818 if msg_format is None: 819 raise Exception(f"Can't resolve sub-message of {attr_spec['name']} for extack") 820 sub_attrs = self.attr_sets[msg_format.attr_set] 821 pathname += f"({value})" 822 else: 823 raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack") 824 offset += 4 825 subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs, 826 offset, target, search_attrs) 827 if subpath is None: 828 return None 829 return '.' + pathname + subpath 830 831 return None 832 833 def _decode_extack(self, request, op, extack, vals): 834 if 'bad-attr-offs' not in extack: 835 return 836 837 msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op) 838 offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header) 839 search_attrs = SpaceAttrs(op.attr_set, vals) 840 path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset, 841 extack['bad-attr-offs'], search_attrs) 842 if path: 843 del extack['bad-attr-offs'] 844 extack['bad-attr'] = path 845 846 def _struct_size(self, name): 847 if name: 848 members = self.consts[name].members 849 size = 0 850 for m in members: 851 if m.type in ['pad', 'binary']: 852 if m.struct: 853 size += self._struct_size(m.struct) 854 else: 855 size += m.len 856 else: 857 format = NlAttr.get_format(m.type, m.byte_order) 858 size += format.size 859 return size 860 else: 861 return 0 862 863 def _decode_struct(self, data, name): 864 members = self.consts[name].members 865 attrs = dict() 866 offset = 0 867 for m in members: 868 value = None 869 if m.type == 'pad': 870 
offset += m.len 871 elif m.type == 'binary': 872 if m.struct: 873 len = self._struct_size(m.struct) 874 value = self._decode_struct(data[offset : offset + len], 875 m.struct) 876 offset += len 877 else: 878 value = data[offset : offset + m.len] 879 offset += m.len 880 else: 881 format = NlAttr.get_format(m.type, m.byte_order) 882 [ value ] = format.unpack_from(data, offset) 883 offset += format.size 884 if value is not None: 885 if m.enum: 886 value = self._decode_enum(value, m) 887 elif m.display_hint: 888 value = self._formatted_string(value, m.display_hint) 889 attrs[m.name] = value 890 return attrs 891 892 def _encode_struct(self, name, vals): 893 members = self.consts[name].members 894 attr_payload = b'' 895 for m in members: 896 value = vals.pop(m.name) if m.name in vals else None 897 if m.type == 'pad': 898 attr_payload += bytearray(m.len) 899 elif m.type == 'binary': 900 if m.struct: 901 if value is None: 902 value = dict() 903 attr_payload += self._encode_struct(m.struct, value) 904 else: 905 if value is None: 906 attr_payload += bytearray(m.len) 907 else: 908 attr_payload += bytes.fromhex(value) 909 else: 910 if value is None: 911 value = 0 912 format = NlAttr.get_format(m.type, m.byte_order) 913 attr_payload += format.pack(value) 914 return attr_payload 915 916 def _formatted_string(self, raw, display_hint): 917 if display_hint == 'mac': 918 formatted = ':'.join('%02x' % b for b in raw) 919 elif display_hint == 'hex': 920 if isinstance(raw, int): 921 formatted = hex(raw) 922 else: 923 formatted = bytes.hex(raw, ' ') 924 elif display_hint in [ 'ipv4', 'ipv6' ]: 925 formatted = format(ipaddress.ip_address(raw)) 926 elif display_hint == 'uuid': 927 formatted = str(uuid.UUID(bytes=raw)) 928 else: 929 formatted = raw 930 return formatted 931 932 def _from_string(self, string, attr_spec): 933 if attr_spec.display_hint in ['ipv4', 'ipv6']: 934 ip = ipaddress.ip_address(string) 935 if attr_spec['type'] == 'binary': 936 raw = ip.packed 937 else: 938 raw = int(ip) 
939 else: 940 raise Exception(f"Display hint '{attr_spec.display_hint}' not implemented" 941 f" when parsing '{attr_spec['name']}'") 942 return raw 943 944 def handle_ntf(self, decoded): 945 msg = dict() 946 if self.include_raw: 947 msg['raw'] = decoded 948 op = self.rsp_by_value[decoded.cmd()] 949 attrs = self._decode(decoded.raw_attrs, op.attr_set.name) 950 if op.fixed_header: 951 attrs.update(self._decode_struct(decoded.raw, op.fixed_header)) 952 953 msg['name'] = op['name'] 954 msg['msg'] = attrs 955 self.async_msg_queue.put(msg) 956 957 def check_ntf(self): 958 while True: 959 try: 960 reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT) 961 except BlockingIOError: 962 return 963 964 nms = NlMsgs(reply) 965 self._recv_dbg_print(reply, nms) 966 for nl_msg in nms: 967 if nl_msg.error: 968 print("Netlink error in ntf!?", os.strerror(-nl_msg.error)) 969 print(nl_msg) 970 continue 971 if nl_msg.done: 972 print("Netlink done while checking for ntf!?") 973 continue 974 975 decoded = self.nlproto.decode(self, nl_msg, None) 976 if decoded.cmd() not in self.async_msg_ids: 977 print("Unexpected msg id while checking for ntf", decoded) 978 continue 979 980 self.handle_ntf(decoded) 981 982 def poll_ntf(self, duration=None): 983 start_time = time.time() 984 selector = selectors.DefaultSelector() 985 selector.register(self.sock, selectors.EVENT_READ) 986 987 while True: 988 try: 989 yield self.async_msg_queue.get_nowait() 990 except queue.Empty: 991 if duration is not None: 992 timeout = start_time + duration - time.time() 993 if timeout <= 0: 994 return 995 else: 996 timeout = None 997 events = selector.select(timeout) 998 if events: 999 self.check_ntf() 1000 1001 def operation_do_attributes(self, name): 1002 """ 1003 For a given operation name, find and return a supported 1004 set of attributes (as a dict). 
1005 """ 1006 op = self.find_operation(name) 1007 if not op: 1008 return None 1009 1010 return op['do']['request']['attributes'].copy() 1011 1012 def _encode_message(self, op, vals, flags, req_seq): 1013 nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK 1014 for flag in flags or []: 1015 nl_flags |= flag 1016 1017 msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq) 1018 if op.fixed_header: 1019 msg += self._encode_struct(op.fixed_header, vals) 1020 search_attrs = SpaceAttrs(op.attr_set, vals) 1021 for name, value in vals.items(): 1022 msg += self._add_attr(op.attr_set.name, name, value, search_attrs) 1023 msg = _genl_msg_finalize(msg) 1024 return msg 1025 1026 def _ops(self, ops): 1027 reqs_by_seq = {} 1028 req_seq = random.randint(1024, 65535) 1029 payload = b'' 1030 for (method, vals, flags) in ops: 1031 op = self.ops[method] 1032 msg = self._encode_message(op, vals, flags, req_seq) 1033 reqs_by_seq[req_seq] = (op, vals, msg, flags) 1034 payload += msg 1035 req_seq += 1 1036 1037 self.sock.send(payload, 0) 1038 1039 done = False 1040 rsp = [] 1041 op_rsp = [] 1042 while not done: 1043 reply = self.sock.recv(self._recv_size) 1044 nms = NlMsgs(reply) 1045 self._recv_dbg_print(reply, nms) 1046 for nl_msg in nms: 1047 if nl_msg.nl_seq in reqs_by_seq: 1048 (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq] 1049 if nl_msg.extack: 1050 nl_msg.annotate_extack(op.attr_set) 1051 self._decode_extack(req_msg, op, nl_msg.extack, vals) 1052 else: 1053 op = None 1054 req_flags = [] 1055 1056 if nl_msg.error: 1057 raise NlError(nl_msg) 1058 if nl_msg.done: 1059 if nl_msg.extack: 1060 print("Netlink warning:") 1061 print(nl_msg) 1062 1063 if Netlink.NLM_F_DUMP in req_flags: 1064 rsp.append(op_rsp) 1065 elif not op_rsp: 1066 rsp.append(None) 1067 elif len(op_rsp) == 1: 1068 rsp.append(op_rsp[0]) 1069 else: 1070 rsp.append(op_rsp) 1071 op_rsp = [] 1072 1073 del reqs_by_seq[nl_msg.nl_seq] 1074 done = len(reqs_by_seq) == 0 1075 break 1076 1077 decoded = 
self.nlproto.decode(self, nl_msg, op) 1078 1079 # Check if this is a reply to our request 1080 if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value: 1081 if decoded.cmd() in self.async_msg_ids: 1082 self.handle_ntf(decoded) 1083 continue 1084 else: 1085 print('Unexpected message: ' + repr(decoded)) 1086 continue 1087 1088 rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name) 1089 if op.fixed_header: 1090 rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header)) 1091 op_rsp.append(rsp_msg) 1092 1093 return rsp 1094 1095 def _op(self, method, vals, flags=None, dump=False): 1096 req_flags = flags or [] 1097 if dump: 1098 req_flags.append(Netlink.NLM_F_DUMP) 1099 1100 ops = [(method, vals, req_flags)] 1101 return self._ops(ops)[0] 1102 1103 def do(self, method, vals, flags=None): 1104 return self._op(method, vals, flags) 1105 1106 def dump(self, method, vals): 1107 return self._op(method, vals, dump=True) 1108 1109 def do_multi(self, ops): 1110 return self._ops(ops) 1111