# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
from enum import Enum
import functools
import os
import random
import socket
import struct
from struct import Struct
import sys
import yaml
import ipaddress
import uuid
import queue
import time

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Namespace of netlink / genetlink / nlctrl / extack constants."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Enum() numbers members from 1, so AttrType(n).name maps the numeric
    # policy attribute type reported by the kernel to its spec name.
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                 's8', 's16', 's32', 's64',
                                 'binary', 'string', 'nul-string',
                                 'nested', 'nested-array',
                                 'bitfield32', 'sint', 'uint'])


class NlError(Exception):
    """Raised when the kernel answers a request with a non-zero error code."""

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        # The message carries a negative errno; keep the positive value so
        # it can be passed straight to os.strerror().
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"


class ConfigError(Exception):
    """Raised for invalid local configuration (e.g. a too-small recv size)."""
    pass


class NlAttr:
    """A single netlink attribute (TLV) parsed at a given offset of a buffer.

    Attributes:
        type: attribute type with the NESTED / NET_BYTEORDER bits masked off.
        is_nest: non-zero when the NLA_F_NESTED bit was set.
        raw: the payload bytes (header stripped, alignment padding excluded).
        full_len: length rounded up to 4 bytes, i.e. the distance to the
            next attribute in the buffer.
    """

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # A length shorter than the 4-byte header is malformed. It would also
        # yield full_len == 0, and every loop that walks attributes by
        # full_len would then never advance.
        if self._len < 4:
            raise Exception(f"Malformed netlink attribute: length {self._len} "
                            f"is shorter than the attribute header")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        # nla_len covers header + payload but not the trailing padding.
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for a scalar type name ('u8' ... 's64').

        byte_order "big-endian" selects big endian, any other truthy value
        selects little endian, and None/empty selects native order.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" \
                else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the payload as a single scalar of the given type."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a variable-width 'sint'/'uint' payload (4 or 8 bytes).

        The first letter of attr_type ('s' or 'u') picks signedness; the
        payload length picks the width.
        """
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode an ASCII, NUL-terminated string payload (last byte dropped)."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as a packed array of native-order scalars."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """All attributes found in msg, starting at offset, in wire order."""

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        msg = ''
        for a in self.attrs:
            if msg:
                msg += '\n'
            msg += repr(a)
        return msg


class NlMsg:
    """A single netlink message (16-byte nlmsghdr + payload) at offset of msg.

    For NLMSG_ERROR and NLMSG_DONE messages the error code is extracted into
    self.error and self.done is set. When the kernel flagged NLM_F_ACK_TLVS,
    the extended ACK attributes are decoded into the self.extack dict.
    If attr_space is given, a top-level "missing attribute" type number is
    translated into the attribute's spec name.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)
        # NlMsgs advances by nl_len, so a length shorter than the header
        # would make it loop forever on the same bytes.
        if self.nl_len < 16:
            raise Exception(f"Malformed netlink message: length {self.nl_len} "
                            f"is shorter than the message header")

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # 4-byte error code followed by the echoed 16-byte request header.
            # Assumes the echo was capped to the header; the sockets in this
            # file enable NETLINK_CAP_ACK.
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        self.extack['miss-type'] = spec['name']
                        if 'doc' in spec:
                            self.extack['miss-type-doc'] = spec['doc']

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY nest into a dict of policy limits.

        Policy attributes not handled below (e.g. PAD, POLICY_IDX) are ignored.
        """
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_id = attr.as_scalar('u32')
                policy['type'] = Netlink.AttrType(type_id).name
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def cmd(self):
        """Return the message type (nlmsg_type)."""
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages packed back to back in one recv() buffer."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            offset += msg.nl_len
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Family name -> {'id', 'name', 'maxattr', 'mcast'} dicts, filled lazily by
# _genl_load_families() on first use.
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build nlmsghdr (without its length field) + genlmsghdr bytes."""
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the 4-byte nlmsg_len field, counting the field itself."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump all genetlink families from nlctrl into genl_family_name_to_id.

    On a netlink error the error is printed and loading stops, leaving
    whatever families were collected so far.
    """
    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        global genl_family_name_to_id
        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in NlAttrs(gm.raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """A generic netlink message: NlMsg plus the 4-byte genlmsghdr."""

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        # Payload after the genlmsghdr (cmd, version, reserved).
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the genetlink command."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is only attached by NetlinkProtocol.decode(); printing a
        # message that was never decoded must not raise AttributeError.
        for a in getattr(self, 'raw_attrs', []):
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Message framing for classic ("netlink-raw") netlink families."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Build an nlmsghdr without the length field (portid 0)."""
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        """Build a request header; command is used as nlmsg_type.

        version is accepted for interface parity with GenlProtocol but is
        not encoded.
        """
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap nl_msg and attach raw_attrs (attributes after the fixed header).

        When op is None it is looked up from the message's command via
        ynl.rsp_by_value.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the multicast group id declared in the spec."""
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value

    def msghdr_size(self):
        """Size in bytes of the protocol header preceding the payload."""
        return 16


class GenlProtocol(NetlinkProtocol):
    """Message framing for generic netlink families (id resolved via nlctrl)."""

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        # KeyError here means the kernel does not know the family.
        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']

    def message(self, flags, command, version, seq=None):
        """Build nlmsghdr (family id as type) + genlmsghdr, without length."""
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the multicast group id reported by the kernel."""
        # 'mcast' is only recorded for families that report groups; treat a
        # missing key as "no groups" rather than leaking a bare KeyError.
        kernel_groups = self.genl_family.get('mcast', {})
        if mcast_name not in kernel_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return kernel_groups[mcast_name]

    def msghdr_size(self):
        """nlmsghdr plus the 4-byte genlmsghdr."""
        return super().msghdr_size() + 4


class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes used to resolve selectors.

    Scopes are searched innermost first. The first scope whose spec defines
    the name must also hold a value for it, otherwise lookup() raises.
    """

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer=None):
        outer_scopes = outer.scopes if outer else []
        inner_scope = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner_scope] + outer_scopes

    def lookup(self, name):
        """Return the value of attribute name from the innermost defining scope."""
        for scope in self.scopes:
            if name in scope.spec:
                if name in scope.values:
                    return scope.values[name]
                spec_name = scope.spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """A netlink family driven by a YNL YAML spec.

    Opens a netlink socket for the family, and exposes every operation in the
    spec as a bound method named after the op's ident_name (each call goes
    through _op()).
    """

    def __init__(self, def_path, schema=None, process_unknown=False,
                 recv_size=0):
        super().__init__(def_path, schema)

        self.include_raw = False
        self.process_unknown = process_unknown

        try:
            if self.proto == "netlink-raw":
                self.nlproto = NetlinkProtocol(self.yaml['name'],
                                               self.yaml['protonum'])
            else:
                self.nlproto = GenlProtocol(self.yaml['name'])
        except KeyError:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

        self._recv_dbg = False
        # Note that netlink will use conservative (min) message size for
        # the first dump recv() on the socket, our setting will only matter
        # from the second recv() on.
        self._recv_size = recv_size if recv_size else 131072
        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
        # for a message, so smaller receive sizes will lead to truncation.
        # Note that the min size for other families may be larger than 4k!
        if self._recv_size < 4000:
            raise ConfigError(f"recv_size {self._recv_size} is too small, "
                              f"must be at least 4000 bytes")

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

        self.async_msg_ids = set()
        self.async_msg_queue = queue.Queue()

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

    def ntf_subscribe(self, mcast_name):
        """Join the multicast group mcast_name to receive notifications."""
        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_id)

    def set_recv_dbg(self, enabled):
        """Enable or disable printing of every received batch to stderr."""
        self._recv_dbg = enabled

    def _recv_dbg_print(self, reply, nl_msgs):
        if not self._recv_dbg:
            return
        print("Recv: read", len(reply), "bytes,",
              len(nl_msgs.msgs), "messages", file=sys.stderr)
        for nl_msg in nl_msgs:
            print("  ", nl_msg, file=sys.stderr)

    def _encode_enum(self, attr_spec, value):
        """Convert an enum entry name (or names, for flags) to its integer."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            scalar = 0
            if isinstance(value, str):
                value = [value]
            for single_value in value:
                # OR, not add: a flag listed twice must not carry into
                # a neighbouring bit.
                scalar |= enum.entries[single_value].user_value(as_flags = True)
            return scalar
        else:
            return enum.entries[value].user_value()

    def _get_scalar(self, attr_spec, value):
        """Return value as an int, falling back to enum-name encoding."""
        try:
            return int(value)
        except (ValueError, TypeError):
            if 'enum' not in attr_spec:
                raise
            return self._encode_enum(attr_spec, value)

    def _add_attr(self, space, name, value, search_attrs):
        """Encode attribute name of attribute-set space as padded TLV bytes.

        A list value for a multi-attr yields one TLV per element.
        search_attrs is the SpaceAttrs scope chain used to resolve
        sub-message selectors.
        """
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value

        if attr.is_multi and isinstance(value, list):
            attr_payload = b''
            for subvalue in value:
                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
            return attr_payload

        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            sub_space = attr['nested-attributes']
            # The nest's values belong to the nested attribute set, so that is
            # the spec its scope must be checked against.
            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
        elif attr["type"] == 'flag':
            if not value:
                # If value is absent or false then skip attribute creation.
                return b''
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            if isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                attr_payload = bytes.fromhex(value)
            elif isinstance(value, dict) and attr.struct_name:
                attr_payload = self._encode_struct(attr.struct_name, value)
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
            scalar = self._get_scalar(attr, value)
            if attr.is_auto_scalar:
                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
            else:
                attr_type = attr["type"]
            fmt = NlAttr.get_format(attr_type, attr.byte_order)
            attr_payload = fmt.pack(scalar)
        elif attr['type'] == "bitfield32":
            scalar_value = self._get_scalar(attr, value["value"])
            scalar_selector = self._get_scalar(attr, value["selector"])
            attr_payload = struct.pack("II", scalar_value, scalar_selector)
        elif attr['type'] == 'sub-message':
            msg_format = self._resolve_selector(attr, search_attrs)
            attr_payload = b''
            if msg_format.fixed_header:
                # _encode_struct pops the header fields out of value, leaving
                # only attributes for the loop below.
                attr_payload += self._encode_struct(msg_format.fixed_header, value)
            if msg_format.attr_set:
                if msg_format.attr_set in self.attr_sets:
                    nl_type |= Netlink.NLA_F_NESTED
                    # SpaceAttrs needs the attribute-set object, not its name.
                    sub_attrs = SpaceAttrs(self.attr_sets[msg_format.attr_set],
                                           value, search_attrs)
                    for subname, subvalue in value.items():
                        attr_payload += self._add_attr(msg_format.attr_set,
                                                       subname, subvalue, sub_attrs)
                else:
                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Map an integer to its enum entry name (a set of names for flags)."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attribute as a struct, C array or raw bytes."""
        if attr_spec.struct_name:
            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = self._formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode_array_attr(self, attr, attr_spec):
        """Decode an indexed-array attribute into a list of its elements."""
        decoded = []
        offset = 0
        while offset < len(attr.raw):
            item = NlAttr(attr.raw, offset)
            offset += item.full_len

            if attr_spec["sub-type"] == 'nest':
                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
                decoded.append({ item.type: subattrs })
            elif attr_spec["sub-type"] == 'binary':
                subattrs = item.as_bin()
                if attr_spec.display_hint:
                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
                decoded.append(subattrs)
            elif attr_spec["sub-type"] in NlAttr.type_formats:
                subattrs = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
                if attr_spec.display_hint:
                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
                decoded.append(subattrs)
            else:
                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
        return decoded

    def _decode_nest_type_value(self, attr, attr_spec):
        """Decode nests whose attribute *types* carry values (type-value)."""
        decoded = {}
        value = attr
        for name in attr_spec['type-value']:
            value = NlAttr(value.raw, 0)
            decoded[name] = value.type
        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
        decoded.update(subattrs)
        return decoded

    def _decode_unknown(self, attr):
        """Best-effort decode of an attribute missing from the spec."""
        if attr.is_nest:
            return self._decode(NlAttrs(attr.raw), None)
        else:
            return attr.as_bin()

    def _rsp_add(self, rsp, name, is_multi, decoded):
        """Store decoded under name in rsp, appending for multi-attrs.

        is_multi None means "unknown": a repeated name is promoted to a list.
        """
        if is_multi is None:
            if name in rsp and not isinstance(rsp[name], list):
                rsp[name] = [rsp[name]]
                is_multi = True
            else:
                is_multi = False

        if not is_multi:
            rsp[name] = decoded
        elif name in rsp:
            rsp[name].append(decoded)
        else:
            rsp[name] = [decoded]

    def _resolve_selector(self, attr_spec, search_attrs):
        """Pick the sub-message format selected by the selector attribute."""
        sub_msg = attr_spec.sub_message
        if sub_msg not in self.sub_msgs:
            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
        sub_msg_spec = self.sub_msgs[sub_msg]

        selector = attr_spec.selector
        value = search_attrs.lookup(selector)
        if value not in sub_msg_spec.formats:
            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")

        spec = sub_msg_spec.formats[value]
        return spec

    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
        """Decode a sub-message attribute: optional fixed header, then attrs."""
        msg_format = self._resolve_selector(attr_spec, search_attrs)
        decoded = {}
        offset = 0
        if msg_format.fixed_header:
            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
            offset = self._struct_size(msg_format.fixed_header)
        if msg_format.attr_set:
            if msg_format.attr_set in self.attr_sets:
                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
                decoded.update(subdict)
            else:
                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
        return decoded

    def _decode(self, attrs, space, outer_attrs=None):
        """Decode attrs of attribute-set space into a dict keyed by name.

        space may be None (unknown nests); then every attribute is treated
        as unknown. outer_attrs is the enclosing SpaceAttrs scope chain.
        """
        rsp = dict()
        attr_space = None
        search_attrs = None
        if space:
            attr_space = self.attr_sets[space]
            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)

        for attr in attrs:
            attr_spec = None
            if attr_space is not None:
                try:
                    attr_spec = attr_space.attrs_by_val[attr.type]
                except KeyError:
                    pass
            if attr_spec is None:
                if not self.process_unknown:
                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
                attr_name = f"UnknownAttr({attr.type})"
                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
                continue

            if attr_spec["type"] == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
                decoded = subdict
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                decoded = True
            elif attr_spec.is_auto_scalar:
                decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
            elif attr_spec["type"] in NlAttr.type_formats:
                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
                if 'enum' in attr_spec:
                    decoded = self._decode_enum(decoded, attr_spec)
                elif attr_spec.display_hint:
                    decoded = self._formatted_string(decoded, attr_spec.display_hint)
            elif attr_spec["type"] == 'indexed-array':
                decoded = self._decode_array_attr(attr, attr_spec)
            elif attr_spec["type"] == 'bitfield32':
                value, selector = struct.unpack("II", attr.raw)
                if 'enum' in attr_spec:
                    value = self._decode_enum(value, attr_spec)
                    selector = self._decode_enum(selector, attr_spec)
                decoded = {"value": value, "selector": selector}
            elif attr_spec["type"] == 'sub-message':
                decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
            elif attr_spec["type"] == 'nest-type-value':
                decoded = self._decode_nest_type_value(attr, attr_spec)
            else:
                if not self.process_unknown:
                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
                decoded = self._decode_unknown(attr)

            self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return the '.a.b' name path of the attribute at byte offset target.

        offset is the message offset of the first attribute in attrs.
        Returns None when no attribute starts exactly at target.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Step over the nest's own 4-byte header into its children.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, op, extack):
        """Replace extack 'bad-attr-offs' with a readable 'bad-attr' path."""
        if 'bad-attr-offs' not in extack:
            return

        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
                                        extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def _struct_size(self, name):
        """Return the encoded size in bytes of struct name (0 for None)."""
        if name:
            members = self.consts[name].members
            size = 0
            for m in members:
                if m.type in ['pad', 'binary']:
                    if m.struct:
                        size += self._struct_size(m.struct)
                    else:
                        size += m.len
                else:
                    fmt = NlAttr.get_format(m.type, m.byte_order)
                    size += fmt.size
            return size
        else:
            return 0

    def _decode_struct(self, data, name):
        """Decode data as struct name into a dict (pad members are skipped)."""
        members = self.consts[name].members
        attrs = dict()
        offset = 0
        for m in members:
            value = None
            if m.type == 'pad':
                offset += m.len
            elif m.type == 'binary':
                if m.struct:
                    member_len = self._struct_size(m.struct)
                    value = self._decode_struct(data[offset : offset + member_len],
                                                m.struct)
                    offset += member_len
                else:
                    value = data[offset : offset + m.len]
                    offset += m.len
            else:
                fmt = NlAttr.get_format(m.type, m.byte_order)
                [ value ] = fmt.unpack_from(data, offset)
                offset += fmt.size
            if value is not None:
                if m.enum:
                    value = self._decode_enum(value, m)
                elif m.display_hint:
                    value = self._formatted_string(value, m.display_hint)
                attrs[m.name] = value
        return attrs

    def _encode_struct(self, name, vals):
        """Encode struct name from vals; missing members are zero-filled.

        Members are popped from vals, so callers can encode whatever is
        left over as attributes.
        """
        members = self.consts[name].members
        attr_payload = b''
        for m in members:
            value = vals.pop(m.name) if m.name in vals else None
            if m.type == 'pad':
                attr_payload += bytearray(m.len)
            elif m.type == 'binary':
                if m.struct:
                    if value is None:
                        value = dict()
                    attr_payload += self._encode_struct(m.struct, value)
                else:
                    if value is None:
                        attr_payload += bytearray(m.len)
                    else:
                        attr_payload += bytes.fromhex(value)
            else:
                if value is None:
                    value = 0
                fmt = NlAttr.get_format(m.type, m.byte_order)
                attr_payload += fmt.pack(value)
        return attr_payload

    def _formatted_string(self, raw, display_hint):
        """Render raw per display_hint (mac, hex, ipv4/ipv6, uuid)."""
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            if isinstance(raw, int):
                formatted = hex(raw)
            else:
                formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def handle_ntf(self, decoded):
        """Decode a notification and push {'name', 'msg'} onto the queue."""
        msg = dict()
        if self.include_raw:
            msg['raw'] = decoded
        op = self.rsp_by_value[decoded.cmd()]
        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
        if op.fixed_header:
            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))

        msg['name'] = op['name']
        msg['msg'] = attrs
        self.async_msg_queue.put(msg)

    def check_ntf(self, interval=0.1):
        """Generator yielding queued notifications, polling the socket.

        Reads without blocking; when nothing is queued it sleeps interval
        seconds. A KeyboardInterrupt during the sleep ends the generator.
        """
        while True:
            try:
                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
                nms = NlMsgs(reply)
                self._recv_dbg_print(reply, nms)
                for nl_msg in nms:
                    if nl_msg.error:
                        print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                        print(nl_msg)
                        continue
                    if nl_msg.done:
                        print("Netlink done while checking for ntf!?")
                        continue

                    decoded = self.nlproto.decode(self, nl_msg, None)
                    if decoded.cmd() not in self.async_msg_ids:
                        print("Unexpected msg id while checking for ntf", decoded)
                        continue

                    self.handle_ntf(decoded)
            except BlockingIOError:
                pass

            try:
                yield self.async_msg_queue.get_nowait()
            except queue.Empty:
                try:
                    time.sleep(interval)
                except KeyboardInterrupt:
                    return

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _encode_message(self, op, vals, flags, req_seq):
        """Build one complete request (header, fixed header, attributes)."""
        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag

        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            msg += self._encode_struct(op.fixed_header, vals)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
        msg = _genl_msg_finalize(msg)
        return msg

    def _ops(self, ops):
        """Send a batch of (method, vals, flags) requests; return one result each.

        Each result is a list for dumps, otherwise None, a single dict, or a
        list of dicts depending on how many replies arrived. Raises NlError
        on a kernel error.
        """
        if not ops:
            return []

        reqs_by_seq = {}
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            reqs_by_seq[req_seq] = (op, msg, flags)
            payload += msg
            req_seq += 1

        self.sock.send(payload, 0)

        # Attribute set used to name "missing attribute" extacks. Captured
        # once here: `op` is rebound per reply below and becomes None for
        # unsolicited messages, which must not break the next recv().
        attr_space = op.attr_set

        done = False
        rsp = []
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            nms = NlMsgs(reply, attr_space=attr_space)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    if nl_msg.extack:
                        self._decode_extack(req_msg, op, nl_msg.extack)
                else:
                    op = None
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    break

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp

    def _op(self, method, vals, flags=None, dump=False):
        """Run a single operation; dump=True adds NLM_F_DUMP."""
        # Copy so that adding NLM_F_DUMP never mutates the caller's list.
        req_flags = list(flags) if flags else []
        if dump:
            req_flags.append(Netlink.NLM_F_DUMP)

        ops = [(method, vals, req_flags)]
        return self._ops(ops)[0]

    def do(self, method, vals, flags=None):
        """Run a 'do' request for operation method."""
        return self._op(method, vals, flags)

    def dump(self, method, vals):
        """Run a 'dump' request for operation method; returns a list."""
        return self._op(method, vals, dump=True)

    def do_multi(self, ops):
        """Send several (method, vals, flags) requests in one batch."""
        return self._ops(ops)