# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
from enum import Enum
import functools
import os
import random
import socket
import struct
from struct import Struct
import sys
import yaml
import ipaddress
import uuid
import queue
import selectors
import time

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Hard-coded netlink, generic netlink and nlctrl constants used below."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    # These reuse the bit values of NLM_F_ROOT / NLM_F_MATCH; what bits
    # 0x100 and up mean depends on the kind of message they are set on.
    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    # Flags that may appear on NLMSG_ERROR / NLMSG_DONE messages;
    # NlMsg checks NLM_F_ACK_TLVS to decide whether to parse extack TLVs.
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    # Flag bits inside nla_type; NlAttr strips them to get the real type.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # The functional Enum API numbers members from 1. These values are
    # looked up from NL_POLICY_TYPE_ATTR_TYPE in NlMsg._decode_policy();
    # presumably they mirror the kernel's NL_ATTR_TYPE_* order -- verify
    # when adding entries.
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                 's8', 's16', 's32', 's64',
                                 'binary', 'string', 'nul-string',
                                 'nested', 'nested-array',
                                 'bitfield32', 'sint', 'uint'])


class NlError(Exception):
    """Raised when the kernel answers a request with a non-zero error."""

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        # NlMsg.error carries a negative errno; keep the positive value.
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"


class ConfigError(Exception):
    """Raised for an invalid YnlFamily configuration."""
    pass


class NlAttr:
    """One netlink attribute (4-byte len/type header + payload) in a buffer.

    Attributes:
        type: attribute type with the NLA_F_* flag bits stripped.
        is_nest: non-zero when the NLA_F_NESTED bit was set.
        payload_len: the on-wire length field (header included).
        full_len: payload_len rounded up to 4 bytes, i.e. the distance to
            the next attribute.
        raw: payload bytes, header excluded.
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        if self._len < 4:
            # Shorter than the header itself is malformed; a length of 0
            # would give full_len 0 and make NlAttrs loop forever.
            raise Exception(f"Invalid attribute length {self._len} at offset {offset}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for *attr_type* (a key of type_formats).

        byte_order "big-endian" selects big endian, any other truthy value
        little endian, and a falsy value native byte order.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as a fixed-size integer of *attr_type*."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a 'sint'/'uint' payload whose width (32/64) is its length."""
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar len payload must be 4 or 8 bytes, got {len(self.raw)}")
        # 's' or 'u' prefix + bit width, e.g. "s32" / "u64"
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode as ASCII, dropping the final (NUL terminator) character."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of *type* scalars."""
        fmt = self.get_format(type)
        return [x[0] for x in fmt.iter_unpack(self.raw)]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """All attributes found in *msg* from *offset* to the end of the buffer."""

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """A netlink message (16-byte nlmsghdr + payload) parsed from a buffer.

    For NLMSG_ERROR and NLMSG_DONE messages ``error`` holds the 32-bit error
    field and ``done`` is 1. When NLM_F_ACK_TLVS is set on such a message
    the extended-ack TLVs are decoded into the ``extack`` dict. If
    *attr_space* is given, a top-level missing attribute type is replaced
    by its spec name.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)
        if self.nl_len < 16:
            # Shorter than its own header: malformed, and NlMsgs would never
            # advance past it.
            raise Exception(f"Invalid netlink message length {self.nl_len} at offset {offset}")

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # 4-byte error + 16-byte echoed request header. This assumes the
            # echo is capped to the header, as requested by the
            # NETLINK_CAP_ACK sockopt set on the sockets in this file.
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        self.extack['miss-type'] = spec['name']
                        if 'doc' in spec:
                            self.extack['miss-type-doc'] = spec['doc']

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY nest into a dict of policy limits.

        The 'type' entry is an AttrType name, or the raw number when the
        value is not listed in Netlink.AttrType.
        """
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_val = attr.as_scalar('u32')
                try:
                    policy['type'] = Netlink.AttrType(type_val).name
                except ValueError:
                    # Newer kernel type this table doesn't know yet
                    policy['type'] = type_val
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def cmd(self):
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages contained in one received buffer."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            # Consecutive messages start on 4-byte boundaries (NLMSG_ALIGN).
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Lazily filled by _genl_load_families(): family name -> dict with 'id',
# 'name' and, when reported, 'maxattr' and 'mcast' ({group name: id}).
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build nlmsghdr (minus its length) + genlmsghdr for a request."""
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the 4-byte total length field to a message built above."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_parse_mcast_groups(groups_attr):
    """Return {group name: group id} from a CTRL_ATTR_MCAST_GROUPS nest."""
    groups = dict()
    for entry in NlAttrs(groups_attr.raw):
        mcast_name = None
        mcast_id = None
        for entry_attr in NlAttrs(entry.raw):
            if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                mcast_name = entry_attr.as_strz()
            elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                mcast_id = entry_attr.as_scalar('u32')
        if mcast_name and mcast_id is not None:
            groups[mcast_name] = mcast_id
    return groups


def _genl_parse_family(gm):
    """Return the 'id'/'name'/'maxattr'/'mcast' entries of one GETFAMILY reply."""
    fam = dict()
    for attr in NlAttrs(gm.raw):
        if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
            fam['id'] = attr.as_scalar('u16')
        elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
            fam['name'] = attr.as_strz()
        elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
            fam['maxattr'] = attr.as_scalar('u32')
        elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
            fam['mcast'] = _genl_parse_mcast_groups(attr)
    return fam


def _genl_load_families():
    """Dump all generic netlink families into genl_family_name_to_id.

    On a kernel error the error is printed and the (possibly partial)
    table is kept.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                fam = _genl_parse_family(GenlMsg(nl_msg))
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """Generic netlink view of an NlMsg: genlmsghdr (cmd, version) + attrs."""

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl) + '\n'
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is only attached by NetlinkProtocol.decode()
        for a in getattr(self, 'raw_attrs', []):
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Request framing/decoding for a classic netlink protocol number."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Return an nlmsghdr without its length (see _genl_msg_finalize())."""
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        # version is unused: raw netlink has no per-message version field
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap *nl_msg* and attach raw_attrs (attributes after the fixed header).

        When *op* is None it is looked up by the message's command.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value

    def msghdr_size(self):
        return 16


class GenlProtocol(NetlinkProtocol):
    """Generic netlink framing; the family id is resolved from the kernel."""

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        # KeyError here means the kernel does not know the family
        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']

    def message(self, flags, command, version, seq=None):
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        # Families without multicast groups have no 'mcast' entry at all
        family_groups = self.genl_family.get('mcast', {})
        if mcast_name not in family_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return family_groups[mcast_name]

    def msghdr_size(self):
        return super().msghdr_size() + 4


class SpaceAttrs:
    """Chain of (attr-set spec, values) scopes, innermost first.

    Used to resolve sub-message selectors, which may name an attribute of
    the current nest or of any enclosing one.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        outer_scopes = outer.scopes if outer else []
        inner_scope = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner_scope] + outer_scopes

    def lookup(self, name):
        """Return the value of *name* from the innermost scope defining it."""
        for scope in self.scopes:
            if name in scope.spec:
                if name in scope.values:
                    return scope.values[name]
                spec_name = scope.spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """Netlink family driven by a YNL spec.

    Opens a netlink socket for the family described by the spec at
    *def_path* and binds one method per spec operation (named by the
    operation's ``ident_name``). Each method encodes a request, sends it
    and decodes the replies.
    """

    def __init__(self, def_path, schema=None, process_unknown=False,
                 recv_size=0):
        super().__init__(def_path, schema)

        self.include_raw = False
        self.process_unknown = process_unknown

        try:
            if self.proto == "netlink-raw":
                self.nlproto = NetlinkProtocol(self.yaml['name'],
                                               self.yaml['protonum'])
            else:
                self.nlproto = GenlProtocol(self.yaml['name'])
        except KeyError:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

        self._recv_dbg = False
        # Note that netlink will use conservative (min) message size for
        # the first dump recv() on the socket, our setting will only matter
        # from the second recv() on.
        self._recv_size = recv_size if recv_size else 131072
        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
        # for a message, so smaller receive sizes will lead to truncation.
        # Note that the min size for other families may be larger than 4k!
        if self._recv_size < 4000:
            raise ConfigError(f"recv_size {self._recv_size} is too small, must be at least 4000")

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

        self.async_msg_ids = set()
        self.async_msg_queue = queue.Queue()

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

    def ntf_subscribe(self, mcast_name):
        """Join the multicast group *mcast_name* to receive notifications."""
        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_id)

    def set_recv_dbg(self, enabled):
        self._recv_dbg = enabled

    def _recv_dbg_print(self, reply, nl_msgs):
        if not self._recv_dbg:
            return
        print("Recv: read", len(reply), "bytes,",
              len(nl_msgs.msgs), "messages", file=sys.stderr)
        for nl_msg in nl_msgs:
            print("  ", nl_msg, file=sys.stderr)

    def _encode_enum(self, attr_spec, value):
        """Map an enum entry name (or a list of flag names) to its integer value."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            scalar = 0
            if isinstance(value, str):
                value = [value]
            for single_value in value:
                scalar += enum.entries[single_value].user_value(as_flags = True)
            return scalar
        else:
            return enum.entries[value].user_value()

    def _get_scalar(self, attr_spec, value):
        """Return int(value), falling back to enum lookup for enum attributes."""
        try:
            return int(value)
        except (ValueError, TypeError) as e:
            if 'enum' not in attr_spec:
                raise e
            return self._encode_enum(attr_spec, value)

    def _add_attr(self, space, name, value, search_attrs):
        """Encode attribute *name* of attr-set *space* (header, payload, padding).

        *search_attrs* (a SpaceAttrs) resolves sub-message selectors. Multi
        attributes given as a list are emitted once per element.
        """
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value

        if attr.is_multi and isinstance(value, list):
            attr_payload = b''
            for subvalue in value:
                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
            return attr_payload

        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            sub_space = attr['nested-attributes']
            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
        elif attr["type"] == 'flag':
            if not value:
                # If value is absent or false then skip attribute creation.
                return b''
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            if isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                attr_payload = bytes.fromhex(value)
            elif isinstance(value, dict) and attr.struct_name:
                attr_payload = self._encode_struct(attr.struct_name, value)
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
            scalar = self._get_scalar(attr, value)
            if attr.is_auto_scalar:
                # sint/uint use 32 bits when the value fits, else 64. Signed
                # values need the signed range; bit_length() ignores the sign.
                if attr["type"][0] == 's':
                    fits_32 = -2**31 <= scalar < 2**31
                else:
                    fits_32 = scalar.bit_length() <= 32
                attr_type = attr["type"][0] + ('32' if fits_32 else '64')
            else:
                attr_type = attr["type"]
            fmt = NlAttr.get_format(attr_type, attr.byte_order)
            attr_payload = fmt.pack(scalar)
        elif attr['type'] == "bitfield32":
            scalar_value = self._get_scalar(attr, value["value"])
            scalar_selector = self._get_scalar(attr, value["selector"])
            attr_payload = struct.pack("II", scalar_value, scalar_selector)
        elif attr['type'] == 'sub-message':
            msg_format = self._resolve_selector(attr, search_attrs)
            # _encode_struct() pops fixed-header keys; copy so the caller's
            # dict is not modified.
            value = dict(value)
            attr_payload = b''
            if msg_format.fixed_header:
                attr_payload += self._encode_struct(msg_format.fixed_header, value)
            if msg_format.attr_set:
                if msg_format.attr_set in self.attr_sets:
                    nl_type |= Netlink.NLA_F_NESTED
                    # SpaceAttrs needs the attr-set spec, not its name
                    sub_attrs = SpaceAttrs(self.attr_sets[msg_format.attr_set],
                                           value, search_attrs)
                    for subname, subvalue in value.items():
                        attr_payload += self._add_attr(msg_format.attr_set,
                                                       subname, subvalue, sub_attrs)
                else:
                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Map an integer to an enum entry name, or a set of names for flags."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attribute as a struct, a scalar array or raw bytes."""
        if attr_spec.struct_name:
            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
            if 'enum' in attr_spec:
                decoded = [self._decode_enum(x, attr_spec) for x in decoded]
            elif attr_spec.display_hint:
                decoded = [self._formatted_string(x, attr_spec.display_hint)
                           for x in decoded]
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = self._formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode_array_attr(self, attr, attr_spec):
        """Decode an indexed-array: a nest whose members are the array items."""
        decoded = []
        offset = 0
        while offset < len(attr.raw):
            item = NlAttr(attr.raw, offset)
            offset += item.full_len

            if attr_spec["sub-type"] == 'nest':
                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
                decoded.append({ item.type: subattrs })
            elif attr_spec["sub-type"] == 'binary':
                subattr = item.as_bin()
                if attr_spec.display_hint:
                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
                decoded.append(subattr)
            elif attr_spec["sub-type"] in NlAttr.type_formats:
                subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
                if 'enum' in attr_spec:
                    subattr = self._decode_enum(subattr, attr_spec)
                elif attr_spec.display_hint:
                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
                decoded.append(subattr)
            else:
                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
        return decoded

    def _decode_nest_type_value(self, attr, attr_spec):
        """Decode nests whose attribute *types* carry values named by 'type-value'."""
        decoded = {}
        value = attr
        for name in attr_spec['type-value']:
            value = NlAttr(value.raw, 0)
            decoded[name] = value.type
        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
        decoded.update(subattrs)
        return decoded

    def _decode_unknown(self, attr):
        """Decode an attribute with no spec: recurse into nests, else raw bytes."""
        if attr.is_nest:
            return self._decode(NlAttrs(attr.raw), None)
        else:
            return attr.as_bin()

    def _rsp_add(self, rsp, name, is_multi, decoded):
        """Store *decoded* under *name*; lists collect multi-attr values.

        is_multi None means unknown: a repeated name is turned into a list.
        """
        if is_multi is None:
            if name in rsp and type(rsp[name]) is not list:
                rsp[name] = [rsp[name]]
                is_multi = True
            else:
                is_multi = False

        if not is_multi:
            rsp[name] = decoded
        elif name in rsp:
            rsp[name].append(decoded)
        else:
            rsp[name] = [decoded]

    def _resolve_selector(self, attr_spec, search_attrs):
        """Return the sub-message format selected by the selector attribute's value."""
        sub_msg = attr_spec.sub_message
        if sub_msg not in self.sub_msgs:
            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
        sub_msg_spec = self.sub_msgs[sub_msg]

        selector = attr_spec.selector
        value = search_attrs.lookup(selector)
        if value not in sub_msg_spec.formats:
            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")

        spec = sub_msg_spec.formats[value]
        return spec

    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
        """Decode a sub-message attribute: optional fixed header + attributes."""
        msg_format = self._resolve_selector(attr_spec, search_attrs)
        decoded = {}
        offset = 0
        if msg_format.fixed_header:
            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
            offset = self._struct_size(msg_format.fixed_header)
        if msg_format.attr_set:
            if msg_format.attr_set in self.attr_sets:
                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
                decoded.update(subdict)
            else:
                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
        return decoded

    def _decode(self, attrs, space, outer_attrs = None):
        """Decode *attrs* (NlAttrs) of attr-set *space* into a dict.

        With space None every attribute is unknown, which is only accepted
        when process_unknown is set.
        """
        rsp = dict()
        attr_space = None
        search_attrs = None
        if space:
            attr_space = self.attr_sets[space]
            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)

        for attr in attrs:
            attr_spec = None
            if attr_space is not None:
                try:
                    attr_spec = attr_space.attrs_by_val[attr.type]
                except KeyError:
                    pass
            if attr_spec is None:
                if not self.process_unknown:
                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
                attr_name = f"UnknownAttr({attr.type})"
                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
                continue

            try:
                if attr_spec["type"] == 'nest':
                    subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
                    decoded = subdict
                elif attr_spec["type"] == 'string':
                    decoded = attr.as_strz()
                elif attr_spec["type"] == 'binary':
                    decoded = self._decode_binary(attr, attr_spec)
                elif attr_spec["type"] == 'flag':
                    decoded = True
                elif attr_spec.is_auto_scalar:
                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
                elif attr_spec["type"] in NlAttr.type_formats:
                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
                    if 'enum' in attr_spec:
                        decoded = self._decode_enum(decoded, attr_spec)
                    elif attr_spec.display_hint:
                        decoded = self._formatted_string(decoded, attr_spec.display_hint)
                elif attr_spec["type"] == 'indexed-array':
                    decoded = self._decode_array_attr(attr, attr_spec)
                elif attr_spec["type"] == 'bitfield32':
                    value, selector = struct.unpack("II", attr.raw)
                    if 'enum' in attr_spec:
                        value = self._decode_enum(value, attr_spec)
                        selector = self._decode_enum(selector, attr_spec)
                    decoded = {"value": value, "selector": selector}
                elif attr_spec["type"] == 'sub-message':
                    decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
                elif attr_spec["type"] == 'nest-type-value':
                    decoded = self._decode_nest_type_value(attr, attr_spec)
                else:
                    if not self.process_unknown:
                        raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
                    decoded = self._decode_unknown(attr)

                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
            except Exception:
                # Add context for nested failures, then propagate unchanged
                print(f"Error decoding '{attr_spec.name}' from '{space}'")
                raise

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return the '.a.b' path of the attribute at byte *target*, or None.

        *offset* is the byte offset of the first attribute in *attrs*
        within the original request.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            offset += 4  # skip this nest's attribute header
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, op, extack):
        """Replace extack 'bad-attr-offs' with a 'bad-attr' path when resolvable."""
        if 'bad-attr-offs' not in extack:
            return

        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
                                        extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    @staticmethod
    def _annotate_extack(extack, attr_space):
        """Replace a top-level numeric 'miss-type' in *extack* with its spec name.

        Missing attributes inside a nest ('miss-nest' present) are left numeric.
        """
        if not attr_space:
            return
        if 'miss-type' not in extack or 'miss-nest' in extack:
            return
        miss_type = extack['miss-type']
        if miss_type in attr_space.attrs_by_val:
            spec = attr_space.attrs_by_val[miss_type]
            extack['miss-type'] = spec['name']
            if 'doc' in spec:
                extack['miss-type-doc'] = spec['doc']

    def _struct_size(self, name):
        """Return the byte size of struct *name* (0 when name is falsy)."""
        if not name:
            return 0
        size = 0
        for m in self.consts[name].members:
            if m.type in ['pad', 'binary']:
                if m.struct:
                    size += self._struct_size(m.struct)
                else:
                    size += m.len
            else:
                fmt = NlAttr.get_format(m.type, m.byte_order)
                size += fmt.size
        return size

    def _decode_struct(self, data, name):
        """Decode *data* as struct *name* into {member name: value} (pads skipped)."""
        members = self.consts[name].members
        attrs = dict()
        offset = 0
        for m in members:
            value = None
            if m.type == 'pad':
                offset += m.len
            elif m.type == 'binary':
                if m.struct:
                    struct_len = self._struct_size(m.struct)
                    value = self._decode_struct(data[offset : offset + struct_len],
                                                m.struct)
                    offset += struct_len
                else:
                    value = data[offset : offset + m.len]
                    offset += m.len
            else:
                fmt = NlAttr.get_format(m.type, m.byte_order)
                [ value ] = fmt.unpack_from(data, offset)
                offset += fmt.size
            if value is not None:
                if m.enum:
                    value = self._decode_enum(value, m)
                elif m.display_hint:
                    value = self._formatted_string(value, m.display_hint)
                attrs[m.name] = value
        return attrs

    def _encode_struct(self, name, vals):
        """Encode struct *name* from *vals*, popping each member key from *vals*.

        Callers rely on the pop: whatever remains in *vals* afterwards is
        encoded as attributes. Missing members are zero-filled. Binary
        members accept bytes or a hex string.
        """
        members = self.consts[name].members
        attr_payload = b''
        for m in members:
            value = vals.pop(m.name) if m.name in vals else None
            if m.type == 'pad':
                attr_payload += bytearray(m.len)
            elif m.type == 'binary':
                if m.struct:
                    # Copy: the nested encode pops keys, keep caller's dict intact
                    attr_payload += self._encode_struct(m.struct, dict(value or {}))
                elif value is None:
                    attr_payload += bytearray(m.len)
                elif isinstance(value, (bytes, bytearray)):
                    attr_payload += value
                else:
                    attr_payload += bytes.fromhex(value)
            else:
                if value is None:
                    value = 0
                fmt = NlAttr.get_format(m.type, m.byte_order)
                attr_payload += fmt.pack(value)
        return attr_payload

    def _formatted_string(self, raw, display_hint):
        """Render *raw* per *display_hint* (mac/hex/ipv4/ipv6/uuid); else unchanged."""
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            if isinstance(raw, int):
                formatted = hex(raw)
            else:
                formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def handle_ntf(self, decoded):
        """Decode a notification and queue {'name', 'msg'[, 'raw']} for poll_ntf()."""
        msg = dict()
        if self.include_raw:
            msg['raw'] = decoded
        op = self.rsp_by_value[decoded.cmd()]
        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
        if op.fixed_header:
            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))

        msg['name'] = op['name']
        msg['msg'] = attrs
        self.async_msg_queue.put(msg)

    def check_ntf(self):
        """Drain the socket without blocking, queueing any notifications."""
        while True:
            try:
                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                decoded = self.nlproto.decode(self, nl_msg, None)
                if decoded.cmd() not in self.async_msg_ids:
                    print("Unexpected msg id while checking for ntf", decoded)
                    continue

                self.handle_ntf(decoded)

    def poll_ntf(self, duration=None):
        """Yield queued notifications, waiting up to *duration* seconds (None = forever)."""
        start_time = time.time()
        selector = selectors.DefaultSelector()
        selector.register(self.sock, selectors.EVENT_READ)

        try:
            while True:
                try:
                    ntf = self.async_msg_queue.get_nowait()
                except queue.Empty:
                    if duration is not None:
                        timeout = start_time + duration - time.time()
                        if timeout <= 0:
                            return
                    else:
                        timeout = None
                    events = selector.select(timeout)
                    if events:
                        self.check_ntf()
                else:
                    yield ntf
        finally:
            # Runs on normal exit and when the caller closes the generator
            selector.close()

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _encode_message(self, op, vals, flags, req_seq):
        """Return the complete request for *op* with *vals* and extra NLM_F_* *flags*."""
        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag

        # _encode_struct() pops fixed-header keys; copy so the caller's vals
        # stay intact (e.g. when the same dict is reused for another request).
        vals = dict(vals) if vals else {}

        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            msg += self._encode_struct(op.fixed_header, vals)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
        msg = _genl_msg_finalize(msg)
        return msg

    def _ops(self, ops):
        """Send (method, vals, flags) requests in one write; return one result each.

        A dump yields a list of replies. Otherwise the result is None, the
        single reply, or a list when several replies arrived.
        """
        reqs_by_seq = {}
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            # `or []` so the NLM_F_DUMP membership test works for flags=None
            reqs_by_seq[req_seq] = (op, msg, flags or [])
            payload += msg
            req_seq += 1

        self.sock.send(payload, 0)

        done = False
        rsp = []
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            nms = NlMsgs(reply)
            # Resolve missing-attr names with the attr-set of the request each
            # message answers. (Previously the loop variable `op` was used,
            # which could be None after an unmatched message.)
            for nl_msg in nms:
                if nl_msg.extack and nl_msg.nl_seq in reqs_by_seq:
                    req_op = reqs_by_seq[nl_msg.nl_seq][0]
                    self._annotate_extack(nl_msg.extack, req_op.attr_set)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    if nl_msg.extack:
                        self._decode_extack(req_msg, op, nl_msg.extack)
                else:
                    op = None
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    break

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                    else:
                        print('Unexpected message: ' + repr(decoded))
                    continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp

    def _op(self, method, vals, flags=None, dump=False):
        """Run one operation; *dump* adds NLM_F_DUMP to the request flags."""
        # Copy so appending NLM_F_DUMP never mutates the caller's list
        req_flags = list(flags or [])
        if dump:
            req_flags.append(Netlink.NLM_F_DUMP)

        ops = [(method, vals, req_flags)]
        return self._ops(ops)[0]

    def do(self, method, vals, flags=None):
        return self._op(method, vals, flags)

    def dump(self, method, vals):
        return self._op(method, vals, dump=True)

    def do_multi(self, ops):
        return self._ops(ops)