# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
from enum import Enum
import functools
import os
import random
import socket
import struct
from struct import Struct
import sys
import ipaddress
import uuid
import queue
import selectors
import time

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Numeric constants of the netlink / generic netlink protocol.

    Several NLM_F_* groups below reuse the same bit values (e.g. NLM_F_ROOT,
    NLM_F_REPLACE and NLM_F_CAPPED are all 0x100); they are interpreted in
    different contexts, so the groups are kept separate.  For instance
    NLM_F_ACK_TLVS is only checked on received messages (see NlMsg).
    """

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    # Flag bits carried in an attribute's type field; masked off to get the
    # plain attribute type.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Attribute types as reported in policy dumps; the functional Enum API
    # numbers members from 1 in list order.
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                  's8', 's16', 's32', 's64',
                                  'binary', 'string', 'nul-string',
                                  'nested', 'nested-array',
                                  'bitfield32', 'sint', 'uint'])


class NlError(Exception):
    """Error reported by the kernel in an NLMSG_ERROR / NLMSG_DONE message.

    ``nl_msg`` is the offending NlMsg (including any extack information);
    ``error`` is the negated error code from that message, i.e. a positive
    errno value when the kernel reported a failure.
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"


class ConfigError(Exception):
    """Invalid YNL configuration (e.g. a receive buffer that is too small)."""
    pass


class NlAttr:
    """A single netlink attribute (TLV) parsed from a byte buffer.

    Attributes:
        type: attribute type with the NLA_F_NESTED / NLA_F_NET_BYTEORDER
            flag bits masked off.
        is_nest: non-zero if the NLA_F_NESTED flag was set.
        payload_len: the attribute's raw length field; note that, despite the
            name, it includes the 4-byte attribute header.
        full_len: payload_len rounded up to a multiple of 4, i.e. the distance
            from this attribute to the next one in the buffer.
        raw: the attribute payload (header stripped, padding excluded).

    Raises:
        Exception: if the length field is smaller than the 4-byte header.
        struct.error: if fewer than 4 bytes are available at ``offset``.
    """

    # Precompiled Struct objects for every fixed-size scalar type, in host
    # ("native"), big-endian and little-endian byte order.
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # A length shorter than the header itself is malformed.  Without this
        # check a zero length yields full_len == 0 and every loop that walks
        # attributes by full_len (e.g. NlAttrs) would never advance.
        if self._len < 4:
            raise Exception(f"Malformed netlink attribute at offset {offset}: "
                            f"length {self._len} is shorter than the 4-byte header")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct used to (un)pack scalar ``attr_type``.

        ``byte_order`` None (or empty) selects host byte order,
        "big-endian" selects big-endian, and any other value selects
        little-endian.  Raises KeyError for non-scalar types.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as one fixed-size scalar of ``attr_type``."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable-width scalar whose size is taken from the payload.

        The first letter of ``attr_type`` ('s' or 'u') selects signedness and
        the payload width (4 or 8 bytes) selects the concrete type, e.g. an
        8-byte payload for a 'uint' attribute is decoded as 'u64'.
        """
        if len(self.raw) not in (4, 8):
            raise Exception(f"Auto-scalar len payload must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII, dropping the last byte (the NUL terminator)."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of host-order scalars of ``type``."""
        fmt = self.get_format(type)
        return [x[0] for x in fmt.iter_unpack(self.raw)]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """All consecutive attributes in ``msg`` from ``offset`` to the end of the buffer."""

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            # full_len includes alignment padding, so this lands on the next attribute.
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """A single netlink message parsed from a buffer at ``offset``.

    The 16-byte nlmsghdr is exposed as nl_len, nl_type, nl_flags, nl_seq and
    nl_portid, and the message body (header stripped) as ``raw``.  For
    NLMSG_ERROR and NLMSG_DONE messages ``error`` holds the signed error code
    from the body and ``done`` is set to 1.  If the message carries
    NLM_F_ACK_TLVS, the extended-ack attributes are decoded into the
    ``extack`` dict (otherwise ``extack`` is None).  When ``attr_space`` is
    given, the extack is annotated with attribute names from that space.

    Raises:
        Exception: if the length field is shorter than the 16-byte header.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)
        # nl_len includes the header; a shorter value is malformed and would
        # stop NlMsgs() from advancing through the buffer.
        if self.nl_len < 16:
            raise Exception(f"Malformed netlink message at offset {offset}: "
                            f"length {self.nl_len} is shorter than the 16-byte header")

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # The error code is followed by the 16-byte header of the request
            # being acknowledged, then the extack TLVs.  This assumes the
            # request payload is not echoed back (NETLINK_CAP_ACK, which the
            # sockets in this file enable).
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = {}
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    # Keep attributes we cannot interpret so they still show up in repr().
                    self.extack.setdefault('unknown', []).append(extack)

            if attr_space:
                self.annotate_extack(attr_space)

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY nest into a dict keyed by policy field name."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                attr_type = attr.as_scalar('u32')
                policy['type'] = Netlink.AttrType(attr_type).name
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def annotate_extack(self, attr_space):
        """ Make extack more human friendly with attribute information """

        # Nothing to annotate when the message carried no extack.
        if not self.extack:
            return
        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
            miss_type = self.extack['miss-type']
            if miss_type in attr_space.attrs_by_val:
                spec = attr_space.attrs_by_val[miss_type]
                self.extack['miss-type'] = spec['name']
                if 'doc' in spec:
                    self.extack['miss-type-doc'] = spec['doc']

    def cmd(self):
        """Return the message type (the command for classic netlink families)."""
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages contained in one receive buffer."""

    def __init__(self, data):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset)
            # Consecutive messages start on 4-byte boundaries (NLMSG_ALIGN).
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Cache of generic netlink families, keyed by family name; populated lazily
# by _genl_load_families() on first use.
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build an nlmsghdr (without its length field) plus a genlmsghdr.

    A random sequence number is used when ``seq`` is None.  The length is
    prepended later by _genl_msg_finalize().
    """
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the 4-byte total-length field to a message body."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump all generic netlink families and cache them by name.

    Fills the module-level ``genl_family_name_to_id`` with dicts holding
    'id' and 'name', plus 'maxattr' and 'mcast' (group name -> group id)
    when the kernel reports them.  On a netlink error the error is printed
    and the cache is left with whatever was collected so far.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        genl_family_name_to_id = {}

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = {}
                for attr in NlAttrs(gm.raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = {}
                        # Each entry is a nest holding one group's name and id.
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """A generic netlink message: an NlMsg whose body starts with a genlmsghdr.

    ``raw`` is the body after the 4-byte genlmsghdr.  ``raw_attrs`` is only
    attached later by NetlinkProtocol.decode().
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the generic netlink command."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs exists only after NetlinkProtocol.decode(); print no
        # attributes instead of raising AttributeError before that.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Message framing for a classic (non-generic) netlink family.

    ``proto_num`` is the netlink protocol number the family's socket is
    opened with.
    """

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Build an nlmsghdr without its length field (random seq if None)."""
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        """Start a request; for classic netlink the command is the message type.

        ``version`` is accepted for interface parity with GenlProtocol and
        ignored here.
        """
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap ``nl_msg`` for this protocol and parse its attributes.

        When ``op`` is None it is looked up by the message's command in
        ``ynl.rsp_by_value``.  Attributes following the op's fixed header are
        stored in ``msg.raw_attrs``.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the multicast group id for ``mcast_name`` from the spec's groups."""
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value

    def msghdr_size(self):
        """Size in bytes of the headers preceding the fixed header (nlmsghdr)."""
        return 16


class GenlProtocol(NetlinkProtocol):
    """Message framing for a generic netlink family.

    The family id and multicast groups are resolved from the kernel (via the
    cached nlctrl dump).  A KeyError is raised if the kernel does not know
    ``family_name``.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        if genl_family_name_to_id is None:
            _genl_load_families()

        # KeyError here means the family is not registered in the kernel.
        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = self.genl_family['id']

    def message(self, flags, command, version, seq=None):
        """Start a request addressed to this family: nlmsghdr + genlmsghdr."""
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the kernel-assigned id of multicast group ``mcast_name``."""
        # Families without multicast groups have no 'mcast' entry at all.
        family_groups = self.genl_family.get('mcast', {})
        if mcast_name not in family_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return family_groups[mcast_name]

    def msghdr_size(self):
        """nlmsghdr plus the 4-byte genlmsghdr."""
        return super().msghdr_size() + 4


class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes used to resolve selectors.

    ``lookup()`` searches from the innermost scope outwards, so a value in a
    nested attribute set shadows one of the same name in an enclosing set.
    """

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer=None):
        outer_scopes = outer.scopes if outer else []
        inner_scope = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner_scope] + outer_scopes

    def lookup(self, name):
        """Return the value of attribute ``name`` from the nearest scope defining it.

        Raises Exception if the nearest attribute set that defines ``name``
        has no value for it, or if no scope defines ``name`` at all.
        """
        for scope in self.scopes:
            if name in scope.spec:
                if name in scope.values:
                    return scope.values[name]
                spec_name = scope.spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")


#
# YNL implementation details.
462# 463 464 465class YnlFamily(SpecFamily): 466 def __init__(self, def_path, schema=None, process_unknown=False, 467 recv_size=0): 468 super().__init__(def_path, schema) 469 470 self.include_raw = False 471 self.process_unknown = process_unknown 472 473 try: 474 if self.proto == "netlink-raw": 475 self.nlproto = NetlinkProtocol(self.yaml['name'], 476 self.yaml['protonum']) 477 else: 478 self.nlproto = GenlProtocol(self.yaml['name']) 479 except KeyError: 480 raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel") 481 482 self._recv_dbg = False 483 # Note that netlink will use conservative (min) message size for 484 # the first dump recv() on the socket, our setting will only matter 485 # from the second recv() on. 486 self._recv_size = recv_size if recv_size else 131072 487 # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo) 488 # for a message, so smaller receive sizes will lead to truncation. 489 # Note that the min size for other families may be larger than 4k! 
490 if self._recv_size < 4000: 491 raise ConfigError() 492 493 self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num) 494 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) 495 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1) 496 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1) 497 498 self.async_msg_ids = set() 499 self.async_msg_queue = queue.Queue() 500 501 for msg in self.msgs.values(): 502 if msg.is_async: 503 self.async_msg_ids.add(msg.rsp_value) 504 505 for op_name, op in self.ops.items(): 506 bound_f = functools.partial(self._op, op_name) 507 setattr(self, op.ident_name, bound_f) 508 509 510 def ntf_subscribe(self, mcast_name): 511 mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups) 512 self.sock.bind((0, 0)) 513 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP, 514 mcast_id) 515 516 def set_recv_dbg(self, enabled): 517 self._recv_dbg = enabled 518 519 def _recv_dbg_print(self, reply, nl_msgs): 520 if not self._recv_dbg: 521 return 522 print("Recv: read", len(reply), "bytes,", 523 len(nl_msgs.msgs), "messages", file=sys.stderr) 524 for nl_msg in nl_msgs: 525 print(" ", nl_msg, file=sys.stderr) 526 527 def _encode_enum(self, attr_spec, value): 528 enum = self.consts[attr_spec['enum']] 529 if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): 530 scalar = 0 531 if isinstance(value, str): 532 value = [value] 533 for single_value in value: 534 scalar += enum.entries[single_value].user_value(as_flags = True) 535 return scalar 536 else: 537 return enum.entries[value].user_value() 538 539 def _get_scalar(self, attr_spec, value): 540 try: 541 return int(value) 542 except (ValueError, TypeError) as e: 543 if 'enum' in attr_spec: 544 return self._encode_enum(attr_spec, value) 545 if attr_spec.display_hint: 546 return self._from_string(value, attr_spec) 547 raise e 548 549 def _add_attr(self, space, name, value, 
search_attrs): 550 try: 551 attr = self.attr_sets[space][name] 552 except KeyError: 553 raise Exception(f"Space '{space}' has no attribute '{name}'") 554 nl_type = attr.value 555 556 if attr.is_multi and isinstance(value, list): 557 attr_payload = b'' 558 for subvalue in value: 559 attr_payload += self._add_attr(space, name, subvalue, search_attrs) 560 return attr_payload 561 562 if attr["type"] == 'nest': 563 nl_type |= Netlink.NLA_F_NESTED 564 sub_space = attr['nested-attributes'] 565 attr_payload = self._add_nest_attrs(value, sub_space, search_attrs) 566 elif attr['type'] == 'indexed-array' and attr['sub-type'] == 'nest': 567 nl_type |= Netlink.NLA_F_NESTED 568 sub_space = attr['nested-attributes'] 569 attr_payload = self._encode_indexed_array(value, sub_space, 570 search_attrs) 571 elif attr["type"] == 'flag': 572 if not value: 573 # If value is absent or false then skip attribute creation. 574 return b'' 575 attr_payload = b'' 576 elif attr["type"] == 'string': 577 attr_payload = str(value).encode('ascii') + b'\x00' 578 elif attr["type"] == 'binary': 579 if value is None: 580 attr_payload = b'' 581 elif isinstance(value, bytes): 582 attr_payload = value 583 elif isinstance(value, str): 584 if attr.display_hint: 585 attr_payload = self._from_string(value, attr) 586 else: 587 attr_payload = bytes.fromhex(value) 588 elif isinstance(value, dict) and attr.struct_name: 589 attr_payload = self._encode_struct(attr.struct_name, value) 590 elif isinstance(value, list) and attr.sub_type in NlAttr.type_formats: 591 format = NlAttr.get_format(attr.sub_type) 592 attr_payload = b''.join([format.pack(x) for x in value]) 593 else: 594 raise Exception(f'Unknown type for binary attribute, value: {value}') 595 elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar: 596 scalar = self._get_scalar(attr, value) 597 if attr.is_auto_scalar: 598 attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64') 599 else: 600 attr_type = attr["type"] 601 format = 
NlAttr.get_format(attr_type, attr.byte_order) 602 attr_payload = format.pack(scalar) 603 elif attr['type'] in "bitfield32": 604 scalar_value = self._get_scalar(attr, value["value"]) 605 scalar_selector = self._get_scalar(attr, value["selector"]) 606 attr_payload = struct.pack("II", scalar_value, scalar_selector) 607 elif attr['type'] == 'sub-message': 608 msg_format, _ = self._resolve_selector(attr, search_attrs) 609 attr_payload = b'' 610 if msg_format.fixed_header: 611 attr_payload += self._encode_struct(msg_format.fixed_header, value) 612 if msg_format.attr_set: 613 if msg_format.attr_set in self.attr_sets: 614 nl_type |= Netlink.NLA_F_NESTED 615 sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs) 616 for subname, subvalue in value.items(): 617 attr_payload += self._add_attr(msg_format.attr_set, 618 subname, subvalue, sub_attrs) 619 else: 620 raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'") 621 else: 622 raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}') 623 624 return self._add_attr_raw(nl_type, attr_payload) 625 626 def _add_attr_raw(self, nl_type, attr_payload): 627 pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4) 628 return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad 629 630 def _add_nest_attrs(self, value, sub_space, search_attrs): 631 sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs) 632 attr_payload = b'' 633 for subname, subvalue in value.items(): 634 attr_payload += self._add_attr(sub_space, subname, subvalue, 635 sub_attrs) 636 return attr_payload 637 638 def _encode_indexed_array(self, vals, sub_space, search_attrs): 639 attr_payload = b'' 640 for i, val in enumerate(vals): 641 idx = i | Netlink.NLA_F_NESTED 642 val_payload = self._add_nest_attrs(val, sub_space, search_attrs) 643 attr_payload += self._add_attr_raw(idx, val_payload) 644 return attr_payload 645 646 def _get_enum_or_unknown(self, enum, raw): 647 try: 648 name = 
enum.entries_by_val[raw].name 649 except KeyError as error: 650 if self.process_unknown: 651 name = f"Unknown({raw})" 652 else: 653 raise error 654 return name 655 656 def _decode_enum(self, raw, attr_spec): 657 enum = self.consts[attr_spec['enum']] 658 if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): 659 i = 0 660 value = set() 661 while raw: 662 if raw & 1: 663 value.add(self._get_enum_or_unknown(enum, i)) 664 raw >>= 1 665 i += 1 666 else: 667 value = self._get_enum_or_unknown(enum, raw) 668 return value 669 670 def _decode_binary(self, attr, attr_spec): 671 if attr_spec.struct_name: 672 decoded = self._decode_struct(attr.raw, attr_spec.struct_name) 673 elif attr_spec.sub_type: 674 decoded = attr.as_c_array(attr_spec.sub_type) 675 if 'enum' in attr_spec: 676 decoded = [ self._decode_enum(x, attr_spec) for x in decoded ] 677 elif attr_spec.display_hint: 678 decoded = [ self._formatted_string(x, attr_spec.display_hint) 679 for x in decoded ] 680 else: 681 decoded = attr.as_bin() 682 if attr_spec.display_hint: 683 decoded = self._formatted_string(decoded, attr_spec.display_hint) 684 return decoded 685 686 def _decode_array_attr(self, attr, attr_spec): 687 decoded = [] 688 offset = 0 689 while offset < len(attr.raw): 690 item = NlAttr(attr.raw, offset) 691 offset += item.full_len 692 693 if attr_spec["sub-type"] == 'nest': 694 subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes']) 695 decoded.append({ item.type: subattrs }) 696 elif attr_spec["sub-type"] == 'binary': 697 subattr = item.as_bin() 698 if attr_spec.display_hint: 699 subattr = self._formatted_string(subattr, attr_spec.display_hint) 700 decoded.append(subattr) 701 elif attr_spec["sub-type"] in NlAttr.type_formats: 702 subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order) 703 if 'enum' in attr_spec: 704 subattr = self._decode_enum(subattr, attr_spec) 705 elif attr_spec.display_hint: 706 subattr = self._formatted_string(subattr, attr_spec.display_hint) 
707 decoded.append(subattr) 708 else: 709 raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}') 710 return decoded 711 712 def _decode_nest_type_value(self, attr, attr_spec): 713 decoded = {} 714 value = attr 715 for name in attr_spec['type-value']: 716 value = NlAttr(value.raw, 0) 717 decoded[name] = value.type 718 subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes']) 719 decoded.update(subattrs) 720 return decoded 721 722 def _decode_unknown(self, attr): 723 if attr.is_nest: 724 return self._decode(NlAttrs(attr.raw), None) 725 else: 726 return attr.as_bin() 727 728 def _rsp_add(self, rsp, name, is_multi, decoded): 729 if is_multi is None: 730 if name in rsp and type(rsp[name]) is not list: 731 rsp[name] = [rsp[name]] 732 is_multi = True 733 else: 734 is_multi = False 735 736 if not is_multi: 737 rsp[name] = decoded 738 elif name in rsp: 739 rsp[name].append(decoded) 740 else: 741 rsp[name] = [decoded] 742 743 def _resolve_selector(self, attr_spec, search_attrs): 744 sub_msg = attr_spec.sub_message 745 if sub_msg not in self.sub_msgs: 746 raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}") 747 sub_msg_spec = self.sub_msgs[sub_msg] 748 749 selector = attr_spec.selector 750 value = search_attrs.lookup(selector) 751 if value not in sub_msg_spec.formats: 752 raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'") 753 754 spec = sub_msg_spec.formats[value] 755 return spec, value 756 757 def _decode_sub_msg(self, attr, attr_spec, search_attrs): 758 msg_format, _ = self._resolve_selector(attr_spec, search_attrs) 759 decoded = {} 760 offset = 0 761 if msg_format.fixed_header: 762 decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header)) 763 offset = self._struct_size(msg_format.fixed_header) 764 if msg_format.attr_set: 765 if msg_format.attr_set in self.attr_sets: 766 subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set) 767 
decoded.update(subdict) 768 else: 769 raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'") 770 return decoded 771 772 def _decode(self, attrs, space, outer_attrs = None): 773 rsp = dict() 774 if space: 775 attr_space = self.attr_sets[space] 776 search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs) 777 778 for attr in attrs: 779 try: 780 attr_spec = attr_space.attrs_by_val[attr.type] 781 except (KeyError, UnboundLocalError): 782 if not self.process_unknown: 783 raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'") 784 attr_name = f"UnknownAttr({attr.type})" 785 self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr)) 786 continue 787 788 try: 789 if attr_spec["type"] == 'nest': 790 subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs) 791 decoded = subdict 792 elif attr_spec["type"] == 'string': 793 decoded = attr.as_strz() 794 elif attr_spec["type"] == 'binary': 795 decoded = self._decode_binary(attr, attr_spec) 796 elif attr_spec["type"] == 'flag': 797 decoded = True 798 elif attr_spec.is_auto_scalar: 799 decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order) 800 if 'enum' in attr_spec: 801 decoded = self._decode_enum(decoded, attr_spec) 802 elif attr_spec["type"] in NlAttr.type_formats: 803 decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order) 804 if 'enum' in attr_spec: 805 decoded = self._decode_enum(decoded, attr_spec) 806 elif attr_spec.display_hint: 807 decoded = self._formatted_string(decoded, attr_spec.display_hint) 808 elif attr_spec["type"] == 'indexed-array': 809 decoded = self._decode_array_attr(attr, attr_spec) 810 elif attr_spec["type"] == 'bitfield32': 811 value, selector = struct.unpack("II", attr.raw) 812 if 'enum' in attr_spec: 813 value = self._decode_enum(value, attr_spec) 814 selector = self._decode_enum(selector, attr_spec) 815 decoded = {"value": value, "selector": selector} 816 elif attr_spec["type"] 
== 'sub-message': 817 decoded = self._decode_sub_msg(attr, attr_spec, search_attrs) 818 elif attr_spec["type"] == 'nest-type-value': 819 decoded = self._decode_nest_type_value(attr, attr_spec) 820 else: 821 if not self.process_unknown: 822 raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}') 823 decoded = self._decode_unknown(attr) 824 825 self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded) 826 except: 827 print(f"Error decoding '{attr_spec.name}' from '{space}'") 828 raise 829 830 return rsp 831 832 def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs): 833 for attr in attrs: 834 try: 835 attr_spec = attr_set.attrs_by_val[attr.type] 836 except KeyError: 837 raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'") 838 if offset > target: 839 break 840 if offset == target: 841 return '.' + attr_spec.name 842 843 if offset + attr.full_len <= target: 844 offset += attr.full_len 845 continue 846 847 pathname = attr_spec.name 848 if attr_spec['type'] == 'nest': 849 sub_attrs = self.attr_sets[attr_spec['nested-attributes']] 850 search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name'])) 851 elif attr_spec['type'] == 'sub-message': 852 msg_format, value = self._resolve_selector(attr_spec, search_attrs) 853 if msg_format is None: 854 raise Exception(f"Can't resolve sub-message of {attr_spec['name']} for extack") 855 sub_attrs = self.attr_sets[msg_format.attr_set] 856 pathname += f"({value})" 857 else: 858 raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack") 859 offset += 4 860 subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs, 861 offset, target, search_attrs) 862 if subpath is None: 863 return None 864 return '.' 
+ pathname + subpath 865 866 return None 867 868 def _decode_extack(self, request, op, extack, vals): 869 if 'bad-attr-offs' not in extack: 870 return 871 872 msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op) 873 offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header) 874 search_attrs = SpaceAttrs(op.attr_set, vals) 875 path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset, 876 extack['bad-attr-offs'], search_attrs) 877 if path: 878 del extack['bad-attr-offs'] 879 extack['bad-attr'] = path 880 881 def _struct_size(self, name): 882 if name: 883 members = self.consts[name].members 884 size = 0 885 for m in members: 886 if m.type in ['pad', 'binary']: 887 if m.struct: 888 size += self._struct_size(m.struct) 889 else: 890 size += m.len 891 else: 892 format = NlAttr.get_format(m.type, m.byte_order) 893 size += format.size 894 return size 895 else: 896 return 0 897 898 def _decode_struct(self, data, name): 899 members = self.consts[name].members 900 attrs = dict() 901 offset = 0 902 for m in members: 903 value = None 904 if m.type == 'pad': 905 offset += m.len 906 elif m.type == 'binary': 907 if m.struct: 908 len = self._struct_size(m.struct) 909 value = self._decode_struct(data[offset : offset + len], 910 m.struct) 911 offset += len 912 else: 913 value = data[offset : offset + m.len] 914 offset += m.len 915 else: 916 format = NlAttr.get_format(m.type, m.byte_order) 917 [ value ] = format.unpack_from(data, offset) 918 offset += format.size 919 if value is not None: 920 if m.enum: 921 value = self._decode_enum(value, m) 922 elif m.display_hint: 923 value = self._formatted_string(value, m.display_hint) 924 attrs[m.name] = value 925 return attrs 926 927 def _encode_struct(self, name, vals): 928 members = self.consts[name].members 929 attr_payload = b'' 930 for m in members: 931 value = vals.pop(m.name) if m.name in vals else None 932 if m.type == 'pad': 933 attr_payload += bytearray(m.len) 934 elif m.type == 'binary': 935 if 
m.struct: 936 if value is None: 937 value = dict() 938 attr_payload += self._encode_struct(m.struct, value) 939 else: 940 if value is None: 941 attr_payload += bytearray(m.len) 942 else: 943 attr_payload += bytes.fromhex(value) 944 else: 945 if value is None: 946 value = 0 947 format = NlAttr.get_format(m.type, m.byte_order) 948 attr_payload += format.pack(value) 949 return attr_payload 950 951 def _formatted_string(self, raw, display_hint): 952 if display_hint == 'mac': 953 formatted = ':'.join('%02x' % b for b in raw) 954 elif display_hint == 'hex': 955 if isinstance(raw, int): 956 formatted = hex(raw) 957 else: 958 formatted = bytes.hex(raw, ' ') 959 elif display_hint in [ 'ipv4', 'ipv6', 'ipv4-or-v6' ]: 960 formatted = format(ipaddress.ip_address(raw)) 961 elif display_hint == 'uuid': 962 formatted = str(uuid.UUID(bytes=raw)) 963 else: 964 formatted = raw 965 return formatted 966 967 def _from_string(self, string, attr_spec): 968 if attr_spec.display_hint in ['ipv4', 'ipv6', 'ipv4-or-v6']: 969 ip = ipaddress.ip_address(string) 970 if attr_spec['type'] == 'binary': 971 raw = ip.packed 972 else: 973 raw = int(ip) 974 elif attr_spec.display_hint == 'hex': 975 if attr_spec['type'] == 'binary': 976 raw = bytes.fromhex(string) 977 else: 978 raw = int(string, 16) 979 else: 980 raise Exception(f"Display hint '{attr_spec.display_hint}' not implemented" 981 f" when parsing '{attr_spec['name']}'") 982 return raw 983 984 def handle_ntf(self, decoded): 985 msg = dict() 986 if self.include_raw: 987 msg['raw'] = decoded 988 op = self.rsp_by_value[decoded.cmd()] 989 attrs = self._decode(decoded.raw_attrs, op.attr_set.name) 990 if op.fixed_header: 991 attrs.update(self._decode_struct(decoded.raw, op.fixed_header)) 992 993 msg['name'] = op['name'] 994 msg['msg'] = attrs 995 self.async_msg_queue.put(msg) 996 997 def check_ntf(self): 998 while True: 999 try: 1000 reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT) 1001 except BlockingIOError: 1002 return 1003 1004 nms = 
NlMsgs(reply) 1005 self._recv_dbg_print(reply, nms) 1006 for nl_msg in nms: 1007 if nl_msg.error: 1008 print("Netlink error in ntf!?", os.strerror(-nl_msg.error)) 1009 print(nl_msg) 1010 continue 1011 if nl_msg.done: 1012 print("Netlink done while checking for ntf!?") 1013 continue 1014 1015 decoded = self.nlproto.decode(self, nl_msg, None) 1016 if decoded.cmd() not in self.async_msg_ids: 1017 print("Unexpected msg id while checking for ntf", decoded) 1018 continue 1019 1020 self.handle_ntf(decoded) 1021 1022 def poll_ntf(self, duration=None): 1023 start_time = time.time() 1024 selector = selectors.DefaultSelector() 1025 selector.register(self.sock, selectors.EVENT_READ) 1026 1027 while True: 1028 try: 1029 yield self.async_msg_queue.get_nowait() 1030 except queue.Empty: 1031 if duration is not None: 1032 timeout = start_time + duration - time.time() 1033 if timeout <= 0: 1034 return 1035 else: 1036 timeout = None 1037 events = selector.select(timeout) 1038 if events: 1039 self.check_ntf() 1040 1041 def operation_do_attributes(self, name): 1042 """ 1043 For a given operation name, find and return a supported 1044 set of attributes (as a dict). 
1045 """ 1046 op = self.find_operation(name) 1047 if not op: 1048 return None 1049 1050 return op['do']['request']['attributes'].copy() 1051 1052 def _encode_message(self, op, vals, flags, req_seq): 1053 nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK 1054 for flag in flags or []: 1055 nl_flags |= flag 1056 1057 msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq) 1058 if op.fixed_header: 1059 msg += self._encode_struct(op.fixed_header, vals) 1060 search_attrs = SpaceAttrs(op.attr_set, vals) 1061 for name, value in vals.items(): 1062 msg += self._add_attr(op.attr_set.name, name, value, search_attrs) 1063 msg = _genl_msg_finalize(msg) 1064 return msg 1065 1066 def _ops(self, ops): 1067 reqs_by_seq = {} 1068 req_seq = random.randint(1024, 65535) 1069 payload = b'' 1070 for (method, vals, flags) in ops: 1071 op = self.ops[method] 1072 msg = self._encode_message(op, vals, flags, req_seq) 1073 reqs_by_seq[req_seq] = (op, vals, msg, flags) 1074 payload += msg 1075 req_seq += 1 1076 1077 self.sock.send(payload, 0) 1078 1079 done = False 1080 rsp = [] 1081 op_rsp = [] 1082 while not done: 1083 reply = self.sock.recv(self._recv_size) 1084 nms = NlMsgs(reply) 1085 self._recv_dbg_print(reply, nms) 1086 for nl_msg in nms: 1087 if nl_msg.nl_seq in reqs_by_seq: 1088 (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq] 1089 if nl_msg.extack: 1090 nl_msg.annotate_extack(op.attr_set) 1091 self._decode_extack(req_msg, op, nl_msg.extack, vals) 1092 else: 1093 op = None 1094 req_flags = [] 1095 1096 if nl_msg.error: 1097 raise NlError(nl_msg) 1098 if nl_msg.done: 1099 if nl_msg.extack: 1100 print("Netlink warning:") 1101 print(nl_msg) 1102 1103 if Netlink.NLM_F_DUMP in req_flags: 1104 rsp.append(op_rsp) 1105 elif not op_rsp: 1106 rsp.append(None) 1107 elif len(op_rsp) == 1: 1108 rsp.append(op_rsp[0]) 1109 else: 1110 rsp.append(op_rsp) 1111 op_rsp = [] 1112 1113 del reqs_by_seq[nl_msg.nl_seq] 1114 done = len(reqs_by_seq) == 0 1115 break 1116 1117 decoded = 
self.nlproto.decode(self, nl_msg, op) 1118 1119 # Check if this is a reply to our request 1120 if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value: 1121 if decoded.cmd() in self.async_msg_ids: 1122 self.handle_ntf(decoded) 1123 continue 1124 else: 1125 print('Unexpected message: ' + repr(decoded)) 1126 continue 1127 1128 rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name) 1129 if op.fixed_header: 1130 rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header)) 1131 op_rsp.append(rsp_msg) 1132 1133 return rsp 1134 1135 def _op(self, method, vals, flags=None, dump=False): 1136 req_flags = flags or [] 1137 if dump: 1138 req_flags.append(Netlink.NLM_F_DUMP) 1139 1140 ops = [(method, vals, req_flags)] 1141 return self._ops(ops)[0] 1142 1143 def do(self, method, vals, flags=None): 1144 return self._op(method, vals, flags) 1145 1146 def dump(self, method, vals): 1147 return self._op(method, vals, dump=True) 1148 1149 def do_multi(self, ops): 1150 return self._ops(ops) 1151