# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
from enum import Enum
import functools
import os
import random
import socket
import struct
from struct import Struct
import sys
import yaml
import ipaddress
import uuid
import queue
import selectors
import time

from .nlspec import SpecFamily

#
# Generic Netlink code which should really live in a library, but I couldn't quickly find one.
#


class Netlink:
    """Numeric constants used to build and parse Netlink / Generic Netlink messages."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # The functional Enum API numbers members from 1 in list order;
    # NlMsg._decode_policy() maps NL_POLICY_TYPE_ATTR_TYPE values through it.
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                  's8', 's16', 's32', 's64',
                                  'binary', 'string', 'nul-string',
                                  'nested', 'nested-array',
                                  'bitfield32', 'sint', 'uint'])


class NlError(Exception):
    """Raised for a reply carrying a non-zero error code.

    ``error`` holds the positive errno; ``nl_msg`` is the offending NlMsg.
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        # The message carries a negative errno; store it as a positive one.
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"


class ConfigError(Exception):
    """Raised for invalid library configuration (e.g. a too-small recv size)."""
    pass


class NlAttr:
    """One Netlink attribute (TLV) parsed from ``raw`` at ``offset``."""

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        # Header is native-endian (u16 length, u16 type); length includes the header.
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        # Attributes are padded to a 4-byte boundary.
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for ``attr_type``; native order unless ``byte_order`` is given."""
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable-width ('sint'/'uint') scalar from a 4- or 8-byte payload."""
        if len(self.raw) not in (4, 8):
            raise Exception(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        # 'sint' -> 's32'/'s64', 'uint' -> 'u32'/'u64'
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        # Drop the trailing NUL terminator.
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        return self.raw

    def as_c_array(self, type):
        fmt = self.get_format(type)
        return [x[0] for x in fmt.iter_unpack(self.raw)]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """Sequence of NlAttr parsed from ``msg`` starting at ``offset``.

    Raises:
        Exception: if an attribute header declares a length shorter than the
            4-byte header itself (this would otherwise loop forever).
    """

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            if attr._len < 4:
                raise Exception(f"Malformed attribute at offset {offset}: length {attr._len}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """One Netlink message parsed from ``msg`` at ``offset``.

    For NLMSG_ERROR / NLMSG_DONE messages, ``error`` and ``done`` are set and
    any extended-ACK TLVs are decoded into the ``extack`` dict. When
    ``attr_space`` is given, a top-level 'miss-type' is translated into the
    attribute name (and doc, if any) from that attribute set.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # 4-byte error code followed by the echoed 16-byte request header;
            # assumes the echo is capped to the header (this module sets
            # NETLINK_CAP_ACK on the sockets it opens).
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        self.extack['miss-type'] = spec['name']
                        if 'doc' in spec:
                            self.extack['miss-type-doc'] = spec['doc']

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY nest into a dict of policy properties."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_val = attr.as_scalar('u32')
                policy['type'] = Netlink.AttrType(type_val).name
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def cmd(self):
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All Netlink messages contained in one received buffer.

    Raises:
        Exception: if a message header declares a length shorter than the
            16-byte header itself (this would otherwise loop forever).
    """

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            if msg.nl_len < 16:
                raise Exception(f"Malformed netlink message at offset {offset}: nl_len {msg.nl_len}")
            offset += msg.nl_len
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Lazily filled by _genl_load_families(): family name -> dict with 'id',
# 'name', and optionally 'maxattr' and 'mcast' (group name -> group id).
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build an nlmsghdr (minus its leading length) followed by a genlmsghdr."""
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the u32 total length (including the length field itself)."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump all Generic Netlink families from nlctrl into genl_family_name_to_id."""
    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        global genl_family_name_to_id
        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in NlAttrs(gm.raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """Generic Netlink view of an NlMsg: genl header fields plus the attribute payload."""

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        # Skip the 4-byte genlmsghdr.
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl) + '\n'
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is only attached by NetlinkProtocol.decode().
        for a in getattr(self, 'raw_attrs', []):
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Message framing for a classic ("netlink-raw") Netlink family."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        # Length is prepended later by _genl_msg_finalize().
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap ``nl_msg`` and attach ``raw_attrs`` parsed after the op's fixed header.

        When ``op`` is None it is looked up from the message's command id.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value

    def msghdr_size(self):
        return 16


class GenlProtocol(NetlinkProtocol):
    """Message framing for a Generic Netlink family; resolves the family id at init.

    Raises KeyError if the kernel does not report ``family_name``.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']

    def message(self, flags, command, version, seq=None):
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        if mcast_name not in self.genl_family['mcast']:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return self.genl_family['mcast'][mcast_name]

    def msghdr_size(self):
        # nlmsghdr plus the 4-byte genlmsghdr
        return super().msghdr_size() + 4


class SpaceAttrs:
    """Chain of (attr-set spec, values) scopes, innermost first, for selector lookups."""

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer=None):
        outer_scopes = outer.scopes if outer else []
        inner_scope = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner_scope] + outer_scopes

    def lookup(self, name):
        """Return the value of attribute ``name`` from the innermost scope defining it.

        Raises if the defining scope has no value for it, or if no scope defines it.
        """
        for scope in self.scopes:
            if name in scope.spec:
                if name in scope.values:
                    return scope.values[name]
                spec_name = scope.spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """Spec-driven Netlink client.

    Loads a family spec via SpecFamily, opens a Netlink socket for it, and
    binds every spec operation as a method (by its ``ident_name``). Those
    methods encode requests from Python dicts and decode replies and
    notifications back into dicts.
    """

    def __init__(self, def_path, schema=None, process_unknown=False,
                 recv_size=0):
        super().__init__(def_path, schema)

        self.include_raw = False
        self.process_unknown = process_unknown

        try:
            if self.proto == "netlink-raw":
                self.nlproto = NetlinkProtocol(self.yaml['name'],
                                               self.yaml['protonum'])
            else:
                self.nlproto = GenlProtocol(self.yaml['name'])
        except KeyError:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

        self._recv_dbg = False
        # Note that netlink will use conservative (min) message size for
        # the first dump recv() on the socket, our setting will only matter
        # from the second recv() on.
        self._recv_size = recv_size if recv_size else 131072
        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
        # for a message, so smaller receive sizes will lead to truncation.
        # Note that the min size for other families may be larger than 4k!
        if self._recv_size < 4000:
            raise ConfigError(f"recv_size {self._recv_size} is too small, must be at least 4000")

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

        self.async_msg_ids = set()
        self.async_msg_queue = queue.Queue()

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

    def ntf_subscribe(self, mcast_name):
        """Join the multicast group ``mcast_name`` on this instance's socket."""
        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_id)

    def set_recv_dbg(self, enabled):
        self._recv_dbg = enabled

    def _recv_dbg_print(self, reply, nl_msgs):
        if not self._recv_dbg:
            return
        print("Recv: read", len(reply), "bytes,",
              len(nl_msgs.msgs), "messages", file=sys.stderr)
        for nl_msg in nl_msgs:
            print("  ", nl_msg, file=sys.stderr)

    def _encode_enum(self, attr_spec, value):
        """Map enum entry name(s) to an integer; flag sets are OR-ed together."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            scalar = 0
            if isinstance(value, str):
                value = [value]
            for single_value in value:
                # OR (not add) so a repeated flag name cannot carry into another bit.
                scalar |= enum.entries[single_value].user_value(as_flags=True)
            return scalar
        else:
            return enum.entries[value].user_value()

    def _get_scalar(self, attr_spec, value):
        """Return ``value`` as int, falling back to enum-name encoding if the spec has one."""
        try:
            return int(value)
        except (ValueError, TypeError):
            if 'enum' not in attr_spec:
                raise
            return self._encode_enum(attr_spec, value)

    def _add_attr(self, space, name, value, search_attrs):
        """Encode attribute ``name`` of attribute-set ``space`` as TLV bytes.

        A list ``value`` for a multi-attr produces one TLV per element.
        ``search_attrs`` is the SpaceAttrs chain used to resolve sub-message
        selectors. The result includes the 4-byte header and trailing padding.
        """
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value

        if attr.is_multi and isinstance(value, list):
            attr_payload = b''
            for subvalue in value:
                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
            return attr_payload

        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            sub_space = attr['nested-attributes']
            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
        elif attr["type"] == 'flag':
            if not value:
                # If value is absent or false then skip attribute creation.
                return b''
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            if isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                attr_payload = bytes.fromhex(value)
            elif isinstance(value, dict) and attr.struct_name:
                attr_payload = self._encode_struct(attr.struct_name, value)
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
            scalar = self._get_scalar(attr, value)
            if attr.is_auto_scalar:
                # Pick the smallest of 32/64 bits that holds the value.
                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
            else:
                attr_type = attr["type"]
            fmt = NlAttr.get_format(attr_type, attr.byte_order)
            attr_payload = fmt.pack(scalar)
        elif attr['type'] == "bitfield32":
            scalar_value = self._get_scalar(attr, value["value"])
            scalar_selector = self._get_scalar(attr, value["selector"])
            attr_payload = struct.pack("II", scalar_value, scalar_selector)
        elif attr['type'] == 'sub-message':
            msg_format = self._resolve_selector(attr, search_attrs)
            attr_payload = b''
            if msg_format.fixed_header:
                attr_payload += self._encode_struct(msg_format.fixed_header, value)
            if msg_format.attr_set:
                if msg_format.attr_set in self.attr_sets:
                    nl_type |= Netlink.NLA_F_NESTED
                    # SpaceAttrs needs the attr-set object (as in the 'nest'
                    # branch), not its name, for selector lookups to work.
                    sub_attrs = SpaceAttrs(self.attr_sets[msg_format.attr_set],
                                           value, search_attrs)
                    for subname, subvalue in value.items():
                        attr_payload += self._add_attr(msg_format.attr_set,
                                                       subname, subvalue, sub_attrs)
                else:
                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Map an integer to its enum entry name, or a set of names for flags."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        if attr_spec.struct_name:
            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = self._formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode_array_attr(self, attr, attr_spec):
        """Decode an indexed-array attribute into a list of its items."""
        decoded = []
        # NlAttrs walks the items exactly like a manual offset loop would.
        for item in NlAttrs(attr.raw):
            if attr_spec["sub-type"] == 'nest':
                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
                decoded.append({item.type: subattrs})
            elif attr_spec["sub-type"] == 'binary':
                subattrs = item.as_bin()
                if attr_spec.display_hint:
                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
                decoded.append(subattrs)
            elif attr_spec["sub-type"] in NlAttr.type_formats:
                subattrs = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
                if attr_spec.display_hint:
                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
                decoded.append(subattrs)
            else:
                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
        return decoded

    def _decode_nest_type_value(self, attr, attr_spec):
        """Peel one nesting level per 'type-value' name, recording each nest's type."""
        decoded = {}
        value = attr
        for name in attr_spec['type-value']:
            value = NlAttr(value.raw, 0)
            decoded[name] = value.type
        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
        decoded.update(subattrs)
        return decoded

    def _decode_unknown(self, attr):
        if attr.is_nest:
            return self._decode(NlAttrs(attr.raw), None)
        else:
            return attr.as_bin()

    def _rsp_add(self, rsp, name, is_multi, decoded):
        """Store ``decoded`` under ``name``; multi-attrs accumulate into a list.

        ``is_multi=None`` means "unknown": promote to a list on the second occurrence.
        """
        if is_multi is None:
            if name in rsp and not isinstance(rsp[name], list):
                rsp[name] = [rsp[name]]
                is_multi = True
            else:
                is_multi = False

        if not is_multi:
            rsp[name] = decoded
        elif name in rsp:
            rsp[name].append(decoded)
        else:
            rsp[name] = [decoded]

    def _resolve_selector(self, attr_spec, search_attrs):
        """Pick the sub-message format selected by the value of ``attr_spec.selector``."""
        sub_msg = attr_spec.sub_message
        if sub_msg not in self.sub_msgs:
            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
        sub_msg_spec = self.sub_msgs[sub_msg]

        selector = attr_spec.selector
        value = search_attrs.lookup(selector)
        if value not in sub_msg_spec.formats:
            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")

        spec = sub_msg_spec.formats[value]
        return spec

    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
        """Decode a sub-message attribute: optional fixed header, then attributes."""
        msg_format = self._resolve_selector(attr_spec, search_attrs)
        decoded = {}
        offset = 0
        if msg_format.fixed_header:
            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
            offset = self._struct_size(msg_format.fixed_header)
        if msg_format.attr_set:
            if msg_format.attr_set in self.attr_sets:
                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
                decoded.update(subdict)
            else:
                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
        return decoded

    def _decode(self, attrs, space, outer_attrs=None):
        """Decode ``attrs`` against attribute-set ``space`` into a dict.

        Attributes not found in the spec are stored as 'UnknownAttr(<type>)'
        when ``process_unknown`` is set, otherwise they raise.
        """
        rsp = dict()
        if space:
            attr_space = self.attr_sets[space]
            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)

        for attr in attrs:
            try:
                attr_spec = attr_space.attrs_by_val[attr.type]
            except (KeyError, UnboundLocalError):
                # UnboundLocalError: no space given, so every attr is unknown.
                if not self.process_unknown:
                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
                attr_name = f"UnknownAttr({attr.type})"
                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
                continue

            try:
                if attr_spec["type"] == 'nest':
                    subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
                    decoded = subdict
                elif attr_spec["type"] == 'string':
                    decoded = attr.as_strz()
                elif attr_spec["type"] == 'binary':
                    decoded = self._decode_binary(attr, attr_spec)
                elif attr_spec["type"] == 'flag':
                    decoded = True
                elif attr_spec.is_auto_scalar:
                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
                elif attr_spec["type"] in NlAttr.type_formats:
                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
                    if 'enum' in attr_spec:
                        decoded = self._decode_enum(decoded, attr_spec)
                    elif attr_spec.display_hint:
                        decoded = self._formatted_string(decoded, attr_spec.display_hint)
                elif attr_spec["type"] == 'indexed-array':
                    decoded = self._decode_array_attr(attr, attr_spec)
                elif attr_spec["type"] == 'bitfield32':
                    value, selector = struct.unpack("II", attr.raw)
                    if 'enum' in attr_spec:
                        value = self._decode_enum(value, attr_spec)
                        selector = self._decode_enum(selector, attr_spec)
                    decoded = {"value": value, "selector": selector}
                elif attr_spec["type"] == 'sub-message':
                    decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
                elif attr_spec["type"] == 'nest-type-value':
                    decoded = self._decode_nest_type_value(attr, attr_spec)
                else:
                    if not self.process_unknown:
                        raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
                    decoded = self._decode_unknown(attr)

                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
            except Exception:
                # Add context about where decoding failed, then propagate.
                print(f"Error decoding '{attr_spec.name}' from '{space}'")
                raise

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return a '.a.b' attribute path to byte offset ``target`` of the request.

        ``offset`` is the byte offset of the first attribute in ``attrs`` within
        the request. Returns None when no attribute starts exactly at ``target``.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Step over this attribute's 4-byte header into its payload.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, op, extack):
        """Replace extack 'bad-attr-offs' with a readable 'bad-attr' path when resolvable."""
        if 'bad-attr-offs' not in extack:
            return

        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
                                        extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def _struct_size(self, name):
        """Size in bytes of struct ``name`` (0 when ``name`` is falsy)."""
        if name:
            members = self.consts[name].members
            size = 0
            for m in members:
                if m.type in ['pad', 'binary']:
                    if m.struct:
                        size += self._struct_size(m.struct)
                    else:
                        size += m.len
                else:
                    fmt = NlAttr.get_format(m.type, m.byte_order)
                    size += fmt.size
            return size
        else:
            return 0

    def _decode_struct(self, data, name):
        """Decode struct ``name`` from the start of ``data`` into a dict of members."""
        members = self.consts[name].members
        attrs = dict()
        offset = 0
        for m in members:
            value = None
            if m.type == 'pad':
                offset += m.len
            elif m.type == 'binary':
                if m.struct:
                    size = self._struct_size(m.struct)
                    value = self._decode_struct(data[offset : offset + size],
                                                m.struct)
                    offset += size
                else:
                    value = data[offset : offset + m.len]
                    offset += m.len
            else:
                fmt = NlAttr.get_format(m.type, m.byte_order)
                [value] = fmt.unpack_from(data, offset)
                offset += fmt.size
            if value is not None:
                if m.enum:
                    value = self._decode_enum(value, m)
                elif m.display_hint:
                    value = self._formatted_string(value, m.display_hint)
                attrs[m.name] = value
        return attrs

    def _encode_struct(self, name, vals):
        """Encode struct ``name`` from ``vals``; missing members are zero-filled.

        Consumed members are popped from ``vals`` so callers can encode what
        remains as attributes.
        """
        members = self.consts[name].members
        attr_payload = b''
        for m in members:
            value = vals.pop(m.name) if m.name in vals else None
            if m.type == 'pad':
                attr_payload += bytearray(m.len)
            elif m.type == 'binary':
                if m.struct:
                    if value is None:
                        value = dict()
                    attr_payload += self._encode_struct(m.struct, value)
                else:
                    if value is None:
                        attr_payload += bytearray(m.len)
                    elif isinstance(value, bytes):
                        # Accept raw bytes as _add_attr() does for binary attrs.
                        attr_payload += value
                    else:
                        attr_payload += bytes.fromhex(value)
            else:
                if value is None:
                    value = 0
                fmt = NlAttr.get_format(m.type, m.byte_order)
                attr_payload += fmt.pack(value)
        return attr_payload

    def _formatted_string(self, raw, display_hint):
        """Render ``raw`` per ``display_hint`` (mac, hex, ipv4/ipv6, uuid); else unchanged."""
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            if isinstance(raw, int):
                formatted = hex(raw)
            else:
                formatted = bytes.hex(raw, ' ')
        elif display_hint in ['ipv4', 'ipv6']:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def handle_ntf(self, decoded):
        """Decode a notification and queue {'name', 'msg'[, 'raw']} on async_msg_queue."""
        msg = dict()
        if self.include_raw:
            msg['raw'] = decoded
        op = self.rsp_by_value[decoded.cmd()]
        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
        if op.fixed_header:
            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))

        msg['name'] = op['name']
        msg['msg'] = attrs
        self.async_msg_queue.put(msg)

    def check_ntf(self):
        """Drain pending messages without blocking, queueing any notifications."""
        while True:
            try:
                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                decoded = self.nlproto.decode(self, nl_msg, None)
                if decoded.cmd() not in self.async_msg_ids:
                    print("Unexpected msg id while checking for ntf", decoded)
                    continue

                self.handle_ntf(decoded)

    def poll_ntf(self, duration=None):
        """Yield queued notifications, waiting up to ``duration`` seconds (forever if None)."""
        # Monotonic clock so wall-clock adjustments don't shorten or extend the wait.
        start_time = time.monotonic()
        selector = selectors.DefaultSelector()
        selector.register(self.sock, selectors.EVENT_READ)

        try:
            while True:
                try:
                    yield self.async_msg_queue.get_nowait()
                except queue.Empty:
                    if duration is not None:
                        timeout = start_time + duration - time.monotonic()
                        if timeout <= 0:
                            return
                    else:
                        timeout = None
                    events = selector.select(timeout)
                    if events:
                        self.check_ntf()
        finally:
            # Runs on normal exit and when the caller closes the generator.
            selector.close()

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _encode_message(self, op, vals, flags, req_seq):
        """Build a complete request message for ``op`` from ``vals`` (None means no values)."""
        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag

        # Work on a copy: _encode_struct() pops fixed-header members, which
        # would otherwise empty the caller's dict.
        vals = dict(vals) if vals else {}

        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            msg += self._encode_struct(op.fixed_header, vals)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
        msg = _genl_msg_finalize(msg)
        return msg

    def _ops(self, ops):
        """Send (method, vals, flags) requests in one write and collect their replies.

        Returns one entry per request, appended as each request's ACK/DONE
        arrives. A dump yields a list; a do yields None, a single dict, or a
        list of dicts. Notifications received meanwhile go to the async queue.
        """
        if not ops:
            return []

        reqs_by_seq = {}
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            reqs_by_seq[req_seq] = (op, msg, flags or [])
            payload += msg
            req_seq += 1

        self.sock.send(payload, 0)

        # Op whose attr-set annotates extack in received messages. Kept apart
        # from `op`, which is set to None for messages that aren't ours.
        extack_op = op
        done = False
        rsp = []
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            nms = NlMsgs(reply, attr_space=extack_op.attr_set)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    extack_op = op
                    if nl_msg.extack:
                        self._decode_extack(req_msg, op, nl_msg.extack)
                else:
                    op = None
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    break

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp

    def _op(self, method, vals, flags=None, dump=False):
        """Run a single operation; ``dump=True`` adds NLM_F_DUMP."""
        # Copy so appending NLM_F_DUMP never mutates the caller's list.
        req_flags = list(flags or [])
        if dump:
            req_flags.append(Netlink.NLM_F_DUMP)

        ops = [(method, vals, req_flags)]
        return self._ops(ops)[0]

    def do(self, method, vals, flags=None):
        return self._op(method, vals, flags)

    def dump(self, method, vals):
        return self._op(method, vals, dump=True)

    def do_multi(self, ops):
        return self._ops(ops)