# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
from enum import Enum
import functools
import os
import random
import socket
import struct
from struct import Struct
import sys
import yaml
import ipaddress
import uuid
import queue
import selectors
import time

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Numeric constants from the netlink / genetlink / nlctrl uAPI."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Functional Enum API numbers members from 1, so 'flag' == 1.
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                  's8', 's16', 's32', 's64',
                                  'binary', 'string', 'nul-string',
                                  'nested', 'nested-array',
                                  'bitfield32', 'sint', 'uint'])


class NlError(Exception):
    """Raised when the kernel answers a request with a non-zero error code."""

    def __init__(self, nl_msg):
        # Pass nl_msg to Exception so .args is populated (keeps the
        # exception picklable / copyable).
        super().__init__(nl_msg)
        self.nl_msg = nl_msg
        # Kernel reports negative errno; store it as a positive errno.
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"


class ConfigError(Exception):
    pass


class NlAttr:
    """One netlink attribute (TLV) parsed out of a raw buffer at an offset."""

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # Strip the NESTED / NET_BYTEORDER flag bits from the type.
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        # nla_len includes the 4-byte header.
        self.payload_len = self._len
        # Attributes are padded to a 4-byte boundary.
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for a scalar type; native order unless byte_order is set."""
        formats = cls.type_formats[attr_type]
        if byte_order:
            return formats.big if byte_order == "big-endian" else formats.little
        return formats.native

    def as_scalar(self, attr_type, byte_order=None):
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a 'sint'/'uint' attribute whose width (32/64) is given by its length."""
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar len payload be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        # Drop the trailing NUL terminator.
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        return self.raw

    def as_c_array(self, type):
        fmt = self.get_format(type)
        return [item[0] for item in fmt.iter_unpack(self.raw)]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """Sequence of NlAttr parsed from `msg` starting at `offset`."""

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            # nla_len covers the 4-byte header, so a shorter value is
            # malformed; a zero length would otherwise never advance offset
            # and loop forever.
            if attr.payload_len < 4:
                raise Exception(f"Malformed netlink attribute at offset {offset}: "
                                f"length {attr.payload_len}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """One netlink message: header fields, payload, error code and extack."""

    # Policy attributes that decode as plain scalars:
    # policy attr type -> (result key, scalar type)
    _POLICY_SCALARS = {
        Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S: ('min-value', 's64'),
        Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S: ('max-value', 's64'),
        Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U: ('min-value', 'u64'),
        Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U: ('max-value', 'u64'),
        Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH: ('min-length', 'u32'),
        Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH: ('max-length', 'u32'),
        Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK: ('bitfield32-mask', 'u32'),
        Netlink.NL_POLICY_TYPE_ATTR_MASK: ('mask', 'u64'),
    }

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # errno (4) + echoed request header (16); NETLINK_CAP_ACK is set
            # by the sockets in this file, so the request payload is not echoed.
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        self.extack['miss-type'] = spec['name']
                        if 'doc' in spec:
                            self.extack['miss-type-doc'] = spec['doc']

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY nest into a dict; unknown attrs are ignored."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                policy['type'] = Netlink.AttrType(attr.as_scalar('u32')).name
            elif attr.type in self._POLICY_SCALARS:
                key, scalar_type = self._POLICY_SCALARS[attr.type]
                policy[key] = attr.as_scalar(scalar_type)
        return policy

    def cmd(self):
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages contained in one recv() buffer."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            # nlmsg_len includes the 16-byte header; a smaller value is
            # malformed and a zero would loop forever.
            if msg.nl_len < 16:
                raise Exception(f"Malformed netlink message at offset {offset}: "
                                f"length {msg.nl_len}")
            offset += msg.nl_len
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# family name -> {'id', 'name', 'maxattr', 'mcast'}; filled lazily by
# _genl_load_families().
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    # Prepend nlmsg_len, which counts itself (4 bytes) plus the rest.
    return struct.pack("I", len(msg) + 4) + msg


def _genl_decode_mcast_groups(raw):
    """Map multicast group name -> id; entries missing a name or id are skipped."""
    groups = dict()
    for entry in NlAttrs(raw):
        mcast_name = None
        mcast_id = None
        for entry_attr in NlAttrs(entry.raw):
            if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                mcast_name = entry_attr.as_strz()
            elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                mcast_id = entry_attr.as_scalar('u32')
        if mcast_name and mcast_id is not None:
            groups[mcast_name] = mcast_id
    return groups


def _genl_decode_family(raw):
    """Decode the attributes of one CTRL_CMD_GETFAMILY reply into a dict."""
    fam = dict()
    for attr in NlAttrs(raw):
        if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
            fam['id'] = attr.as_scalar('u16')
        elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
            fam['name'] = attr.as_strz()
        elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
            fam['maxattr'] = attr.as_scalar('u32')
        elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
            fam['mcast'] = _genl_decode_mcast_groups(attr.raw)
    return fam


def _genl_load_families():
    """Dump all genetlink families from nlctrl into genl_family_name_to_id."""
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            for nl_msg in NlMsgs(reply):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                fam = _genl_decode_family(GenlMsg(nl_msg).raw)
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """Generic netlink view of an NlMsg: genl header split off the payload."""

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is only attached by NetlinkProtocol.decode(); messages
        # constructed directly (e.g. while loading families) lack it.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Message framing for a raw (non-genetlink) netlink protocol."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        # nlmsg_len is prepended later by _genl_msg_finalize().
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap nl_msg and attach raw_attrs (attributes after any fixed header)."""
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value

    def msghdr_size(self):
        return 16


class GenlProtocol(NetlinkProtocol):
    """Message framing for a genetlink family, resolved by name via nlctrl."""

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        # KeyError here means the kernel doesn't know the family.
        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']

    def message(self, flags, command, version, seq=None):
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        if mcast_name not in self.genl_family['mcast']:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return self.genl_family['mcast'][mcast_name]

    def msghdr_size(self):
        # netlink header + 4-byte genl header
        return super().msghdr_size() + 4


class SpaceAttrs:
    """Chain of (attr-set spec, values) scopes used to resolve selector attributes.

    Lookups search the innermost scope first, then each enclosing one.
    """

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer=None):
        outer_scopes = outer.scopes if outer else []
        inner_scope = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner_scope] + outer_scopes

    def lookup(self, name):
        for scope in self.scopes:
            if name in scope.spec:
                if name in scope.values:
                    return scope.values[name]
                spec_name = scope.spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """Netlink family driven by a YNL YAML spec.

    Every operation in the spec is exposed as a bound method named
    op.ident_name, which forwards to _op(); see __init__.
    """

    def __init__(self, def_path, schema=None, process_unknown=False,
                 recv_size=0):
        super().__init__(def_path, schema)

        self.include_raw = False
        self.process_unknown = process_unknown

        try:
            if self.proto == "netlink-raw":
                self.nlproto = NetlinkProtocol(self.yaml['name'],
                                               self.yaml['protonum'])
            else:
                self.nlproto = GenlProtocol(self.yaml['name'])
        except KeyError:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

        self._recv_dbg = False
        # Note that netlink will use conservative (min) message size for
        # the first dump recv() on the socket, our setting will only matter
        # from the second recv() on.
        self._recv_size = recv_size if recv_size else 131072
        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
        # for a message, so smaller receive sizes will lead to truncation.
        # Note that the min size for other families may be larger than 4k!
        if self._recv_size < 4000:
            raise ConfigError(f"recv_size {self._recv_size} is too small, must be at least 4000")

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

        self.async_msg_ids = set()
        self.async_msg_queue = queue.Queue()

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

    def ntf_subscribe(self, mcast_name):
        """Join the multicast group `mcast_name` on this family's socket."""
        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_id)

    def set_recv_dbg(self, enabled):
        self._recv_dbg = enabled

    def _recv_dbg_print(self, reply, nl_msgs):
        if not self._recv_dbg:
            return
        print("Recv: read", len(reply), "bytes,",
              len(nl_msgs.msgs), "messages", file=sys.stderr)
        for nl_msg in nl_msgs:
            print("    ", nl_msg, file=sys.stderr)

    def _encode_enum(self, attr_spec, value):
        """Encode an enum name (or list of flag names) into its integer value."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            scalar = 0
            if isinstance(value, str):
                value = [value]
            for single_value in value:
                scalar += enum.entries[single_value].user_value(as_flags = True)
            return scalar
        else:
            return enum.entries[value].user_value()

    def _get_scalar(self, attr_spec, value):
        """Convert value to int, falling back to enum lookup when the spec has one."""
        try:
            return int(value)
        except (ValueError, TypeError) as e:
            if 'enum' not in attr_spec:
                raise e
            return self._encode_enum(attr_spec, value)

    def _add_attr(self, space, name, value, search_attrs):
        """Encode attribute `name` of attr-set `space` (with padding) as bytes.

        search_attrs is the SpaceAttrs chain used to resolve sub-message
        selectors.
        """
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value

        if attr.is_multi and isinstance(value, list):
            # Multi-attrs are emitted as repeated attributes of the same type.
            return b''.join(self._add_attr(space, name, subvalue, search_attrs)
                            for subvalue in value)

        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            sub_space = attr['nested-attributes']
            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
        elif attr["type"] == 'flag':
            if not value:
                # If value is absent or false then skip attribute creation.
                return b''
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            if isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                attr_payload = bytes.fromhex(value)
            elif isinstance(value, dict) and attr.struct_name:
                attr_payload = self._encode_struct(attr.struct_name, value)
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
            scalar = self._get_scalar(attr, value)
            if attr.is_auto_scalar:
                # Same rule as the kernel's nla_put_sint()/nla_put_uint():
                # use 32 bits when the value round-trips, else 64.
                if attr["type"][0] == 's':
                    fits_32 = -(1 << 31) <= scalar < (1 << 31)
                else:
                    fits_32 = scalar.bit_length() <= 32
                attr_type = attr["type"][0] + ('32' if fits_32 else '64')
            else:
                attr_type = attr["type"]
            fmt = NlAttr.get_format(attr_type, attr.byte_order)
            attr_payload = fmt.pack(scalar)
        elif attr['type'] == "bitfield32":
            scalar_value = self._get_scalar(attr, value["value"])
            scalar_selector = self._get_scalar(attr, value["selector"])
            attr_payload = struct.pack("II", scalar_value, scalar_selector)
        elif attr['type'] == 'sub-message':
            msg_format = self._resolve_selector(attr, search_attrs)
            attr_payload = b''
            if msg_format.fixed_header:
                # _encode_struct() pops header members; work on a copy so the
                # caller's dict is left untouched.
                value = dict(value)
                attr_payload += self._encode_struct(msg_format.fixed_header, value)
            if msg_format.attr_set:
                if msg_format.attr_set in self.attr_sets:
                    nl_type |= Netlink.NLA_F_NESTED
                    sub_attrs = SpaceAttrs(self.attr_sets[msg_format.attr_set],
                                           value, search_attrs)
                    for subname, subvalue in value.items():
                        attr_payload += self._add_attr(msg_format.attr_set,
                                                       subname, subvalue, sub_attrs)
                else:
                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Map an integer to its enum name, or to a set of names for flags."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            bit = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[bit].name)
                raw >>= 1
                bit += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        if attr_spec.struct_name:
            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = self._formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode_array_attr(self, attr, attr_spec):
        """Decode an 'indexed-array' attribute into a list of its items."""
        decoded = []
        for item in NlAttrs(attr.raw):
            if attr_spec["sub-type"] == 'nest':
                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
                decoded.append({ item.type: subattrs })
            elif attr_spec["sub-type"] == 'binary':
                subattr = item.as_bin()
                if attr_spec.display_hint:
                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
                decoded.append(subattr)
            elif attr_spec["sub-type"] in NlAttr.type_formats:
                subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
                if 'enum' in attr_spec:
                    subattr = self._decode_enum(subattr, attr_spec)
                elif attr_spec.display_hint:
                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
                decoded.append(subattr)
            else:
                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
        return decoded

    def _decode_nest_type_value(self, attr, attr_spec):
        """Decode nests whose attribute *types* carry values named by 'type-value'."""
        decoded = {}
        value = attr
        for name in attr_spec['type-value']:
            value = NlAttr(value.raw, 0)
            decoded[name] = value.type
        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
        decoded.update(subattrs)
        return decoded

    def _decode_unknown(self, attr):
        if attr.is_nest:
            return self._decode(NlAttrs(attr.raw), None)
        else:
            return attr.as_bin()

    def _rsp_add(self, rsp, name, is_multi, decoded):
        """Store decoded under name; is_multi=None means "list it on repeat"."""
        if is_multi is None:
            if name in rsp and not isinstance(rsp[name], list):
                rsp[name] = [rsp[name]]
                is_multi = True
            else:
                is_multi = False

        if not is_multi:
            rsp[name] = decoded
        elif name in rsp:
            rsp[name].append(decoded)
        else:
            rsp[name] = [decoded]

    def _resolve_selector(self, attr_spec, search_attrs):
        """Pick the sub-message format selected by the attribute's selector value."""
        sub_msg = attr_spec.sub_message
        if sub_msg not in self.sub_msgs:
            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
        sub_msg_spec = self.sub_msgs[sub_msg]

        selector = attr_spec.selector
        value = search_attrs.lookup(selector)
        if value not in sub_msg_spec.formats:
            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")

        spec = sub_msg_spec.formats[value]
        return spec

    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
        msg_format = self._resolve_selector(attr_spec, search_attrs)
        decoded = {}
        offset = 0
        if msg_format.fixed_header:
            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
            offset = self._struct_size(msg_format.fixed_header)
        if msg_format.attr_set:
            if msg_format.attr_set in self.attr_sets:
                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
                decoded.update(subdict)
            else:
                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
        return decoded

    def _decode_attr(self, attr, attr_spec, search_attrs):
        """Decode one attribute's payload according to its spec entry."""
        attr_type = attr_spec["type"]
        if attr_type == 'nest':
            return self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
        if attr_type == 'string':
            return attr.as_strz()
        if attr_type == 'binary':
            return self._decode_binary(attr, attr_spec)
        if attr_type == 'flag':
            return True
        if attr_spec.is_auto_scalar:
            return attr.as_auto_scalar(attr_type, attr_spec.byte_order)
        if attr_type in NlAttr.type_formats:
            decoded = attr.as_scalar(attr_type, attr_spec.byte_order)
            if 'enum' in attr_spec:
                decoded = self._decode_enum(decoded, attr_spec)
            elif attr_spec.display_hint:
                decoded = self._formatted_string(decoded, attr_spec.display_hint)
            return decoded
        if attr_type == 'indexed-array':
            return self._decode_array_attr(attr, attr_spec)
        if attr_type == 'bitfield32':
            value, selector = struct.unpack("II", attr.raw)
            if 'enum' in attr_spec:
                value = self._decode_enum(value, attr_spec)
                selector = self._decode_enum(selector, attr_spec)
            return {"value": value, "selector": selector}
        if attr_type == 'sub-message':
            return self._decode_sub_msg(attr, attr_spec, search_attrs)
        if attr_type == 'nest-type-value':
            return self._decode_nest_type_value(attr, attr_spec)
        if not self.process_unknown:
            raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
        return self._decode_unknown(attr)

    def _decode(self, attrs, space, outer_attrs = None):
        """Decode NlAttrs from attr-set `space` into a dict keyed by attribute name.

        With space=None every attribute is treated as unknown (which raises
        unless process_unknown is set).
        """
        rsp = dict()
        attr_space = None
        search_attrs = None
        if space:
            attr_space = self.attr_sets[space]
            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)

        for attr in attrs:
            attr_spec = None
            if attr_space is not None:
                try:
                    attr_spec = attr_space.attrs_by_val[attr.type]
                except KeyError:
                    pass
            if attr_spec is None:
                if not self.process_unknown:
                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
                attr_name = f"UnknownAttr({attr.type})"
                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
                continue

            try:
                decoded = self._decode_attr(attr, attr_spec, search_attrs)
                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
            except Exception:
                print(f"Error decoding '{attr_spec.name}' from '{space}'")
                raise

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return a '.a.b' attribute path for byte offset `target`, or None."""
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Step over this attribute's 4-byte header into its payload.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, op, extack):
        """Replace extack 'bad-attr-offs' with a readable 'bad-attr' path if possible."""
        if 'bad-attr-offs' not in extack:
            return

        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
                                        extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def _struct_size(self, name):
        """Byte size of struct `name`; 0 when name is falsy."""
        if not name:
            return 0
        size = 0
        for m in self.consts[name].members:
            if m.type in ['pad', 'binary']:
                if m.struct:
                    size += self._struct_size(m.struct)
                else:
                    size += m.len
            else:
                size += NlAttr.get_format(m.type, m.byte_order).size
        return size

    def _decode_struct(self, data, name):
        """Decode struct `name` from the start of `data` into a dict of members."""
        members = self.consts[name].members
        attrs = dict()
        offset = 0
        for m in members:
            value = None
            if m.type == 'pad':
                offset += m.len
            elif m.type == 'binary':
                if m.struct:
                    size = self._struct_size(m.struct)
                    value = self._decode_struct(data[offset : offset + size],
                                                m.struct)
                    offset += size
                else:
                    value = data[offset : offset + m.len]
                    offset += m.len
            else:
                fmt = NlAttr.get_format(m.type, m.byte_order)
                [ value ] = fmt.unpack_from(data, offset)
                offset += fmt.size
            if value is not None:
                if m.enum:
                    value = self._decode_enum(value, m)
                elif m.display_hint:
                    value = self._formatted_string(value, m.display_hint)
                attrs[m.name] = value
        return attrs

    def _encode_struct(self, name, vals):
        """Encode struct `name` from `vals`; missing members are zero-filled.

        Members that are encoded are popped from `vals` so the caller can
        encode the remaining keys as attributes.
        """
        members = self.consts[name].members
        attr_payload = b''
        for m in members:
            value = vals.pop(m.name, None)
            if m.type == 'pad':
                attr_payload += bytearray(m.len)
            elif m.type == 'binary':
                if m.struct:
                    if value is None:
                        value = dict()
                    attr_payload += self._encode_struct(m.struct, value)
                elif value is None:
                    attr_payload += bytearray(m.len)
                elif isinstance(value, (bytes, bytearray)):
                    # Accept raw bytes too, as produced by _decode_struct().
                    attr_payload += value
                else:
                    attr_payload += bytes.fromhex(value)
            else:
                if value is None:
                    value = 0
                fmt = NlAttr.get_format(m.type, m.byte_order)
                attr_payload += fmt.pack(value)
        return attr_payload

    def _formatted_string(self, raw, display_hint):
        """Render raw per display_hint ('mac', 'hex', 'ipv4'/'ipv6', 'uuid')."""
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            if isinstance(raw, int):
                formatted = hex(raw)
            else:
                formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = str(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def handle_ntf(self, decoded):
        """Decode a notification and queue {'name', 'msg'[, 'raw']} for the user."""
        msg = dict()
        if self.include_raw:
            msg['raw'] = decoded
        op = self.rsp_by_value[decoded.cmd()]
        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
        if op.fixed_header:
            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))

        msg['name'] = op['name']
        msg['msg'] = attrs
        self.async_msg_queue.put(msg)

    def check_ntf(self):
        """Drain pending notifications from the socket without blocking."""
        while True:
            try:
                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                decoded = self.nlproto.decode(self, nl_msg, None)
                if decoded.cmd() not in self.async_msg_ids:
                    print("Unexpected msg id while checking for ntf", decoded)
                    continue

                self.handle_ntf(decoded)

    def poll_ntf(self, duration=None):
        """Yield queued notifications, waiting up to `duration` seconds (None = forever)."""
        # Monotonic clock: wall-clock jumps must not stretch/shrink the timeout.
        start_time = time.monotonic()
        selector = selectors.DefaultSelector()
        try:
            selector.register(self.sock, selectors.EVENT_READ)

            while True:
                try:
                    msg = self.async_msg_queue.get_nowait()
                except queue.Empty:
                    pass
                else:
                    yield msg
                    continue

                if duration is not None:
                    timeout = start_time + duration - time.monotonic()
                    if timeout <= 0:
                        return
                else:
                    timeout = None
                events = selector.select(timeout)
                if events:
                    self.check_ntf()
        finally:
            # Release the selector's OS resources even if the consumer stops
            # iterating early.
            selector.close()

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _encode_message(self, op, vals, flags, req_seq):
        """Build a finalized request message for `op` from `vals`."""
        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag

        # _encode_struct() pops fixed-header members; copy so the caller's
        # dict survives (and can be reused for another request).
        vals = dict(vals or {})

        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            msg += self._encode_struct(op.fixed_header, vals)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
        msg = _genl_msg_finalize(msg)
        return msg

    def _ops(self, ops):
        """Send (method, vals, flags) requests in one send() and collect replies.

        Returns one entry per request, appended in the order the final
        ACK/DONE for each request arrives.
        """
        reqs_by_seq = {}
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            reqs_by_seq[req_seq] = (op, msg, flags)
            payload += msg
            req_seq += 1

        self.sock.send(payload, 0)

        done = False
        rsp = []
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    if nl_msg.extack:
                        self._decode_extack(req_msg, op, nl_msg.extack)
                else:
                    op = None
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    break

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                    else:
                        print('Unexpected message: ' + repr(decoded))
                    continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp

    def _op(self, method, vals, flags=None, dump=False):
        """Run a single operation and return its reply (see _ops)."""
        # Copy so appending NLM_F_DUMP never mutates the caller's list.
        req_flags = list(flags or [])
        if dump:
            req_flags.append(Netlink.NLM_F_DUMP)

        ops = [(method, vals, req_flags)]
        return self._ops(ops)[0]

    def do(self, method, vals, flags=None):
        return self._op(method, vals, flags)

    def dump(self, method, vals):
        return self._op(method, vals, dump=True)

    def do_multi(self, ops):
        return self._ops(ops)