# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
from enum import Enum
import functools
import os
import random
import socket
import struct
from struct import Struct
import sys
import yaml
import ipaddress
import uuid
import queue
import selectors
import time

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Namespace of raw netlink / generic netlink protocol constants."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    # NOTE: the 0x100..0x800 bits are reused by the groups below; which
    # meaning applies depends on the kind of message being sent/received.
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    # Flag bits carried in the attribute type field; masked off to get the type.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Functional Enum API numbers members from 1, matching the policy
    # attribute-type values decoded in NlMsg._decode_policy().
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                  's8', 's16', 's32', 's64',
                                  'binary', 'string', 'nul-string',
                                  'nested', 'nested-array',
                                  'bitfield32', 'sint', 'uint'])


class NlError(Exception):
    """Error reply from the kernel.

    ``error`` holds the positive errno (the negation of the message's error
    field); ``nl_msg`` is the offending NlMsg.
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"


class ConfigError(Exception):
    """Raised for invalid local configuration (e.g. a too-small recv size)."""
    pass


class NlAttr:
    """A single netlink attribute (TLV) parsed from a buffer.

    The 4-byte header is read as native-endian (u16 length, u16 type); the
    length includes the header, and ``full_len`` rounds it up to the 4-byte
    alignment used to find the next attribute.
    """

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        # NOTE: despite the name this is nla_len, which includes the header.
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for ``attr_type``.

        Uses native order unless ``byte_order`` is given, in which case
        "big-endian" selects big-endian and anything else little-endian.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable-width ('sint'/'uint') scalar from a 4 or 8 byte payload."""
        if len(self.raw) not in (4, 8):
            raise Exception(f"Auto-scalar len payload must be 4 or 8 bytes, got {len(self.raw)}")
        # 'sint' -> 's32'/'s64', 'uint' -> 'u32'/'u64'
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        # Drop the trailing NUL terminator.
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        return self.raw

    def as_c_array(self, type):
        fmt = self.get_format(type)
        return [x[0] for x in fmt.iter_unpack(self.raw)]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """Sequence of attributes parsed from ``msg`` starting at ``offset``."""

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """A netlink message parsed from ``msg`` at ``offset``.

    Parses the 16-byte header; for NLMSG_ERROR / NLMSG_DONE it also extracts
    the error code and, if NLM_F_ACK_TLVS is set, the extended ACK attributes.
    ``attr_space``, when given, is used to name a reported missing attribute.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # error code (4) + echoed request header (16)
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = {}
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    self.extack.setdefault('unknown', []).append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        self.extack['miss-type'] = spec['name']
                        if 'doc' in spec:
                            self.extack['miss-type-doc'] = spec['doc']

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY nest into a dict."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                policy_type = attr.as_scalar('u32')
                policy['type'] = Netlink.AttrType(policy_type).name
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def cmd(self):
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages contained in one recv() buffer."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            offset += msg.nl_len
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Cache of generic netlink families: name -> {'id', 'name', 'maxattr', 'mcast'}.
# Filled lazily by _genl_load_families().
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build a netlink + genetlink header without the leading length field."""
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the total message length (including the 4-byte length field)."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump all generic netlink families via nlctrl into genl_family_name_to_id."""
    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        global genl_family_name_to_id
        genl_family_name_to_id = {}

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = {}
                for attr in NlAttrs(gm.raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = {}
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """Generic netlink view of an NlMsg: genl header plus attribute payload."""

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is only attached by NetlinkProtocol.decode(); messages
        # built elsewhere (e.g. _genl_load_families) do not have it.
        for a in getattr(self, 'raw_attrs', []):
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Classic (non-generic) netlink protocol handling for a family."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Build a netlink header without the leading length field."""
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        # Classic netlink has no version field; the command is the nl_type.
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Decode ``nl_msg`` and attach ``raw_attrs`` (attributes after the fixed header).

        When ``op`` is None it is looked up from the message's command.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value

    def msghdr_size(self):
        return 16


class GenlProtocol(NetlinkProtocol):
    """Generic netlink protocol; resolves the family id from the kernel.

    Raises KeyError if the kernel does not report ``family_name``.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = self.genl_family['id']

    def message(self, flags, command, version, seq=None):
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        # Multicast group ids come from the kernel, not the spec.
        if mcast_name not in self.genl_family['mcast']:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return self.genl_family['mcast'][mcast_name]

    def msghdr_size(self):
        # netlink header + 4-byte genetlink header
        return super().msghdr_size() + 4


class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes used to resolve selectors.

    The innermost scope is searched first, then enclosing ones.
    """

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer=None):
        outer_scopes = outer.scopes if outer else []
        inner_scope = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner_scope] + outer_scopes

    def lookup(self, name):
        """Return the value of attribute ``name`` from the nearest scope defining it."""
        for scope in self.scopes:
            if name in scope.spec:
                if name in scope.values:
                    return scope.values[name]
                spec_name = scope.spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """Netlink family driven by a YAML spec.

    Every operation in the spec is exposed as a bound method named after the
    operation's ``ident_name``; calling it sends the request and returns the
    decoded reply (see ``_op``).
    """

    def __init__(self, def_path, schema=None, process_unknown=False,
                 recv_size=0):
        """Load the spec and open a netlink socket for the family.

        :param def_path: path of the YAML family spec (passed to SpecFamily)
        :param schema: optional schema (passed to SpecFamily)
        :param process_unknown: decode attributes/types missing from the spec
            generically instead of raising
        :param recv_size: recv() buffer size in bytes; 0 selects 131072
        :raises Exception: if the family cannot be resolved
        :raises ConfigError: if the receive size is below 4000 bytes
        """
        super().__init__(def_path, schema)

        self.include_raw = False
        self.process_unknown = process_unknown

        try:
            if self.proto == "netlink-raw":
                self.nlproto = NetlinkProtocol(self.yaml['name'],
                                               self.yaml['protonum'])
            else:
                self.nlproto = GenlProtocol(self.yaml['name'])
        except KeyError:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

        self._recv_dbg = False
        # Note that netlink will use conservative (min) message size for
        # the first dump recv() on the socket, our setting will only matter
        # from the second recv() on.
        self._recv_size = recv_size if recv_size else 131072
        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
        # for a message, so smaller receive sizes will lead to truncation.
        # Note that the min size for other families may be larger than 4k!
        if self._recv_size < 4000:
            raise ConfigError(f"recv_size {self._recv_size} is too small, must be at least 4000 bytes")

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

        self.async_msg_ids = set()
        self.async_msg_queue = queue.Queue()

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

    def ntf_subscribe(self, mcast_name):
        """Join multicast group ``mcast_name`` to receive its notifications."""
        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_id)

    def set_recv_dbg(self, enabled):
        self._recv_dbg = enabled

    def _recv_dbg_print(self, reply, nl_msgs):
        if not self._recv_dbg:
            return
        print("Recv: read", len(reply), "bytes,",
              len(nl_msgs.msgs), "messages", file=sys.stderr)
        for nl_msg in nl_msgs:
            print("    ", nl_msg, file=sys.stderr)

    def _encode_enum(self, attr_spec, value):
        """Encode an enum name (or, for flags, a name or iterable of names)."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            scalar = 0
            if isinstance(value, str):
                value = [value]
            for single_value in value:
                # OR, not add: repeating a flag must not carry into other bits.
                scalar |= enum.entries[single_value].user_value(as_flags=True)
            return scalar
        return enum.entries[value].user_value()

    def _get_scalar(self, attr_spec, value):
        """Convert ``value`` to int, falling back to enum lookup if the spec has one."""
        try:
            return int(value)
        except (ValueError, TypeError) as e:
            if 'enum' not in attr_spec:
                raise e
            return self._encode_enum(attr_spec, value)

    def _add_attr(self, space, name, value, search_attrs):
        """Encode attribute ``name`` of attribute-set ``space`` as padded TLV bytes.

        ``search_attrs`` is the SpaceAttrs chain used to resolve sub-message
        selectors.
        """
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value

        if attr.is_multi and isinstance(value, list):
            return b''.join(self._add_attr(space, name, subvalue, search_attrs)
                            for subvalue in value)

        attr_kind = attr["type"]
        if attr_kind == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            sub_space = attr['nested-attributes']
            # Scope for the nest's own values is the nested attribute-set,
            # mirroring what _decode() builds when parsing.
            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
            attr_payload = b''
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
        elif attr_kind == 'flag':
            if not value:
                # If value is absent or false then skip attribute creation.
                return b''
            attr_payload = b''
        elif attr_kind == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr_kind == 'binary':
            if isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                attr_payload = bytes.fromhex(value)
            elif isinstance(value, dict) and attr.struct_name:
                attr_payload = self._encode_struct(attr.struct_name, value)
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr_kind in NlAttr.type_formats or attr.is_auto_scalar:
            scalar = self._get_scalar(attr, value)
            if attr.is_auto_scalar:
                scalar_type = attr_kind[0] + ('32' if scalar.bit_length() <= 32 else '64')
            else:
                scalar_type = attr_kind
            fmt = NlAttr.get_format(scalar_type, attr.byte_order)
            attr_payload = fmt.pack(scalar)
        elif attr_kind == "bitfield32":
            scalar_value = self._get_scalar(attr, value["value"])
            scalar_selector = self._get_scalar(attr, value["selector"])
            attr_payload = struct.pack("II", scalar_value, scalar_selector)
        elif attr_kind == 'sub-message':
            msg_format = self._resolve_selector(attr, search_attrs)
            attr_payload = b''
            if msg_format.fixed_header:
                # Pops the fixed-header members out of ``value``.
                attr_payload += self._encode_struct(msg_format.fixed_header, value)
            if msg_format.attr_set:
                if msg_format.attr_set not in self.attr_sets:
                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
                nl_type |= Netlink.NLA_F_NESTED
                # SpaceAttrs needs the attribute-set object, not its name.
                sub_attrs = SpaceAttrs(self.attr_sets[msg_format.attr_set], value,
                                       search_attrs)
                for subname, subvalue in value.items():
                    attr_payload += self._add_attr(msg_format.attr_set,
                                                   subname, subvalue, sub_attrs)
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Map an integer to its enum name (or set of names for flags)."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        if attr_spec.struct_name:
            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = self._formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode_array_attr(self, attr, attr_spec):
        """Decode an indexed-array attribute into a list."""
        decoded = []
        offset = 0
        while offset < len(attr.raw):
            item = NlAttr(attr.raw, offset)
            offset += item.full_len

            if attr_spec["sub-type"] == 'nest':
                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
                decoded.append({item.type: subattrs})
            elif attr_spec["sub-type"] == 'binary':
                subattrs = item.as_bin()
                if attr_spec.display_hint:
                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
                decoded.append(subattrs)
            elif attr_spec["sub-type"] in NlAttr.type_formats:
                subattrs = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
                if attr_spec.display_hint:
                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
                decoded.append(subattrs)
            else:
                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
        return decoded

    def _decode_nest_type_value(self, attr, attr_spec):
        """Decode nests whose attribute *types* carry values (one per 'type-value' level)."""
        decoded = {}
        value = attr
        for name in attr_spec['type-value']:
            value = NlAttr(value.raw, 0)
            decoded[name] = value.type
        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
        decoded.update(subattrs)
        return decoded

    def _decode_unknown(self, attr):
        if attr.is_nest:
            return self._decode(NlAttrs(attr.raw), None)
        else:
            return attr.as_bin()

    def _rsp_add(self, rsp, name, is_multi, decoded):
        """Store ``decoded`` under ``name``.

        ``is_multi`` None means "unknown": a repeated name is promoted to a list.
        """
        if is_multi is None:
            if name in rsp and not isinstance(rsp[name], list):
                rsp[name] = [rsp[name]]
                is_multi = True
            else:
                is_multi = False

        if not is_multi:
            rsp[name] = decoded
        elif name in rsp:
            rsp[name].append(decoded)
        else:
            rsp[name] = [decoded]

    def _resolve_selector(self, attr_spec, search_attrs):
        """Pick the sub-message format selected by the value of the selector attribute."""
        sub_msg = attr_spec.sub_message
        if sub_msg not in self.sub_msgs:
            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
        sub_msg_spec = self.sub_msgs[sub_msg]

        selector = attr_spec.selector
        value = search_attrs.lookup(selector)
        if value not in sub_msg_spec.formats:
            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")

        spec = sub_msg_spec.formats[value]
        return spec

    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
        """Decode a sub-message attribute: optional fixed header plus attributes."""
        msg_format = self._resolve_selector(attr_spec, search_attrs)
        decoded = {}
        offset = 0
        if msg_format.fixed_header:
            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
            offset = self._struct_size(msg_format.fixed_header)
        if msg_format.attr_set:
            if msg_format.attr_set not in self.attr_sets:
                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
            subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
            decoded.update(subdict)
        return decoded

    def _decode(self, attrs, space, outer_attrs=None):
        """Decode ``attrs`` of attribute-set ``space`` into a dict keyed by attribute name.

        With a falsy ``space`` every attribute is treated as unknown.
        """
        rsp = dict()
        attr_space = None
        search_attrs = None
        if space:
            attr_space = self.attr_sets[space]
            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)

        for attr in attrs:
            if attr_space is None or attr.type not in attr_space.attrs_by_val:
                if not self.process_unknown:
                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
                attr_name = f"UnknownAttr({attr.type})"
                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
                continue
            attr_spec = attr_space.attrs_by_val[attr.type]

            if attr_spec["type"] == 'nest':
                decoded = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'],
                                       search_attrs)
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                decoded = True
            elif attr_spec.is_auto_scalar:
                decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
            elif attr_spec["type"] in NlAttr.type_formats:
                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
                if 'enum' in attr_spec:
                    decoded = self._decode_enum(decoded, attr_spec)
                elif attr_spec.display_hint:
                    decoded = self._formatted_string(decoded, attr_spec.display_hint)
            elif attr_spec["type"] == 'indexed-array':
                decoded = self._decode_array_attr(attr, attr_spec)
            elif attr_spec["type"] == 'bitfield32':
                value, selector = struct.unpack("II", attr.raw)
                if 'enum' in attr_spec:
                    value = self._decode_enum(value, attr_spec)
                    selector = self._decode_enum(selector, attr_spec)
                decoded = {"value": value, "selector": selector}
            elif attr_spec["type"] == 'sub-message':
                decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
            elif attr_spec["type"] == 'nest-type-value':
                decoded = self._decode_nest_type_value(attr, attr_spec)
            else:
                if not self.process_unknown:
                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
                decoded = self._decode_unknown(attr)

            self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return the dotted attribute path at byte ``target`` of the request, or None."""
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Step over this attribute's 4-byte header into its payload.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, op, extack):
        """Replace 'bad-attr-offs' in ``extack`` with a readable 'bad-attr' path."""
        if 'bad-attr-offs' not in extack:
            return

        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
                                        extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def _struct_size(self, name):
        """Byte size of struct ``name``; 0 when ``name`` is falsy."""
        if not name:
            return 0
        size = 0
        for m in self.consts[name].members:
            if m.type in ['pad', 'binary']:
                if m.struct:
                    size += self._struct_size(m.struct)
                else:
                    size += m.len
            else:
                fmt = NlAttr.get_format(m.type, m.byte_order)
                size += fmt.size
        return size

    def _decode_struct(self, data, name):
        """Decode ``data`` as struct ``name`` into a dict of member values (pads skipped)."""
        members = self.consts[name].members
        attrs = dict()
        offset = 0
        for m in members:
            value = None
            if m.type == 'pad':
                offset += m.len
            elif m.type == 'binary':
                if m.struct:
                    struct_len = self._struct_size(m.struct)
                    value = self._decode_struct(data[offset : offset + struct_len],
                                                m.struct)
                    offset += struct_len
                else:
                    value = data[offset : offset + m.len]
                    offset += m.len
            else:
                fmt = NlAttr.get_format(m.type, m.byte_order)
                [value] = fmt.unpack_from(data, offset)
                offset += fmt.size
            if value is not None:
                if m.enum:
                    value = self._decode_enum(value, m)
                elif m.display_hint:
                    value = self._formatted_string(value, m.display_hint)
                attrs[m.name] = value
        return attrs

    def _encode_struct(self, name, vals):
        """Encode struct ``name`` from ``vals``.

        NOTE: consumes (pops) the member keys from ``vals``; callers rely on
        this so the remaining keys can be encoded as attributes. Missing
        members are zero-filled.
        """
        members = self.consts[name].members
        attr_payload = b''
        for m in members:
            value = vals.pop(m.name) if m.name in vals else None
            if m.type == 'pad':
                attr_payload += bytearray(m.len)
            elif m.type == 'binary':
                if m.struct:
                    if value is None:
                        value = dict()
                    attr_payload += self._encode_struct(m.struct, value)
                else:
                    if value is None:
                        attr_payload += bytearray(m.len)
                    else:
                        attr_payload += bytes.fromhex(value)
            else:
                if value is None:
                    value = 0
                fmt = NlAttr.get_format(m.type, m.byte_order)
                attr_payload += fmt.pack(value)
        return attr_payload

    def _formatted_string(self, raw, display_hint):
        """Render ``raw`` according to ``display_hint`` (unknown hints pass through)."""
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            if isinstance(raw, int):
                formatted = hex(raw)
            else:
                formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def handle_ntf(self, decoded):
        """Decode a notification and put {'name', 'msg'[, 'raw']} on async_msg_queue."""
        msg = dict()
        if self.include_raw:
            msg['raw'] = decoded
        op = self.rsp_by_value[decoded.cmd()]
        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
        if op.fixed_header:
            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))

        msg['name'] = op['name']
        msg['msg'] = attrs
        self.async_msg_queue.put(msg)

    def check_ntf(self):
        """Drain pending notifications from the socket without blocking."""
        while True:
            try:
                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                decoded = self.nlproto.decode(self, nl_msg, None)
                if decoded.cmd() not in self.async_msg_ids:
                    print("Unexpected msg id while checking for ntf", decoded)
                    continue

                self.handle_ntf(decoded)

    def poll_ntf(self, duration=None):
        """Yield notifications, waiting up to ``duration`` seconds (forever if None)."""
        start_time = time.time()
        selector = selectors.DefaultSelector()
        selector.register(self.sock, selectors.EVENT_READ)

        try:
            while True:
                try:
                    yield self.async_msg_queue.get_nowait()
                except queue.Empty:
                    if duration is not None:
                        timeout = start_time + duration - time.time()
                        if timeout <= 0:
                            return
                    else:
                        timeout = None
                    events = selector.select(timeout)
                    if events:
                        self.check_ntf()
        finally:
            # Runs when the generator finishes or is closed/collected.
            selector.close()

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _encode_message(self, op, vals, flags, req_seq):
        """Build a complete request message (header, fixed header, attributes)."""
        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag

        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            msg += self._encode_struct(op.fixed_header, vals)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
        msg = _genl_msg_finalize(msg)
        return msg

    def _ops(self, ops):
        """Send ``ops`` ((method, vals, flags) tuples) in one batch and collect replies.

        Returns one entry per request: a list for dumps, otherwise None, the
        single reply, or a list of replies.
        """
        reqs_by_seq = {}
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            reqs_by_seq[req_seq] = (op, msg, flags)
            payload += msg
            req_seq += 1

        self.sock.send(payload, 0)

        done = False
        rsp = []
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    if nl_msg.extack:
                        self._decode_extack(req_msg, op, nl_msg.extack)
                else:
                    op = None
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    break

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp

    def _op(self, method, vals, flags=None, dump=False):
        """Run a single operation and return its response."""
        # Copy so that adding NLM_F_DUMP never mutates the caller's list.
        req_flags = list(flags) if flags else []
        if dump:
            req_flags.append(Netlink.NLM_F_DUMP)

        ops = [(method, vals, req_flags)]
        return self._ops(ops)[0]

    def do(self, method, vals, flags=None):
        return self._op(method, vals, flags)

    def dump(self, method, vals):
        return self._op(method, vals, dump=True)

    def do_multi(self, ops):
        return self._ops(ops)