1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause 2 3from collections import namedtuple 4from enum import Enum 5import functools 6import os 7import random 8import socket 9import struct 10from struct import Struct 11import sys 12import ipaddress 13import uuid 14import queue 15import selectors 16import time 17 18from .nlspec import SpecFamily 19 20# 21# Generic Netlink code which should really be in some library, but I can't quickly find one. 22# 23 24 25class Netlink: 26 # Netlink socket 27 SOL_NETLINK = 270 28 29 NETLINK_ADD_MEMBERSHIP = 1 30 NETLINK_CAP_ACK = 10 31 NETLINK_EXT_ACK = 11 32 NETLINK_GET_STRICT_CHK = 12 33 34 # Netlink message 35 NLMSG_ERROR = 2 36 NLMSG_DONE = 3 37 38 NLM_F_REQUEST = 1 39 NLM_F_ACK = 4 40 NLM_F_ROOT = 0x100 41 NLM_F_MATCH = 0x200 42 43 NLM_F_REPLACE = 0x100 44 NLM_F_EXCL = 0x200 45 NLM_F_CREATE = 0x400 46 NLM_F_APPEND = 0x800 47 48 NLM_F_CAPPED = 0x100 49 NLM_F_ACK_TLVS = 0x200 50 51 NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH 52 53 NLA_F_NESTED = 0x8000 54 NLA_F_NET_BYTEORDER = 0x4000 55 56 NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER 57 58 # Genetlink defines 59 NETLINK_GENERIC = 16 60 61 GENL_ID_CTRL = 0x10 62 63 # nlctrl 64 CTRL_CMD_GETFAMILY = 3 65 66 CTRL_ATTR_FAMILY_ID = 1 67 CTRL_ATTR_FAMILY_NAME = 2 68 CTRL_ATTR_MAXATTR = 5 69 CTRL_ATTR_MCAST_GROUPS = 7 70 71 CTRL_ATTR_MCAST_GRP_NAME = 1 72 CTRL_ATTR_MCAST_GRP_ID = 2 73 74 # Extack types 75 NLMSGERR_ATTR_MSG = 1 76 NLMSGERR_ATTR_OFFS = 2 77 NLMSGERR_ATTR_COOKIE = 3 78 NLMSGERR_ATTR_POLICY = 4 79 NLMSGERR_ATTR_MISS_TYPE = 5 80 NLMSGERR_ATTR_MISS_NEST = 6 81 82 # Policy types 83 NL_POLICY_TYPE_ATTR_TYPE = 1 84 NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2 85 NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3 86 NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4 87 NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5 88 NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6 89 NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7 90 NL_POLICY_TYPE_ATTR_POLICY_IDX = 8 91 NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9 92 NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10 93 NL_POLICY_TYPE_ATTR_PAD = 11 94 NL_POLICY_TYPE_ATTR_MASK = 12 95 96 AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64', 97 's8', 's16', 's32', 's64', 98 'binary', 'string', 'nul-string', 99 'nested', 'nested-array', 100 'bitfield32', 'sint', 'uint']) 101 102class NlError(Exception): 103 def __init__(self, nl_msg): 104 self.nl_msg = nl_msg 105 self.error = -nl_msg.error 106 107 def __str__(self): 108 msg = "Netlink error: " 109 110 extack = self.nl_msg.extack.copy() if self.nl_msg.extack else {} 111 if 'msg' in extack: 112 msg += extack['msg'] + ': ' 113 del extack['msg'] 114 msg += os.strerror(self.error) 115 if extack: 116 msg += ' ' + str(extack) 117 return msg 118 119 120class ConfigError(Exception): 121 pass 122 123 124class NlAttr: 125 ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little']) 126 type_formats = { 127 'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")), 128 's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")), 129 'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")), 130 's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")), 131 'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")), 132 's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")), 133 'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")), 134 's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q")) 135 } 136 137 def __init__(self, raw, offset): 138 self._len, self._type = struct.unpack("HH", raw[offset : offset + 4]) 139 self.type = self._type & ~Netlink.NLA_TYPE_MASK 140 self.is_nest = self._type & Netlink.NLA_F_NESTED 141 self.payload_len = self._len 142 self.full_len = (self.payload_len + 3) & ~3 143 self.raw = raw[offset + 4 : offset + self.payload_len] 144 145 @classmethod 146 def get_format(cls, attr_type, byte_order=None): 147 format = cls.type_formats[attr_type] 148 if byte_order: 149 return format.big if byte_order == "big-endian" \ 150 else format.little 151 return format.native 152 153 def as_scalar(self, attr_type, byte_order=None): 154 format = self.get_format(attr_type, byte_order) 155 return format.unpack(self.raw)[0] 156 157 def as_auto_scalar(self, attr_type, byte_order=None): 158 if len(self.raw) != 4 and len(self.raw) != 8: 159 raise Exception(f"Auto-scalar len payload be 4 or 8 bytes, got {len(self.raw)}") 160 real_type = attr_type[0] + str(len(self.raw) * 8) 161 format = self.get_format(real_type, byte_order) 162 return format.unpack(self.raw)[0] 163 164 def as_strz(self): 165 return self.raw.decode('ascii')[:-1] 166 167 def as_bin(self): 168 return self.raw 169 170 def as_c_array(self, type): 171 format = self.get_format(type) 172 return [ x[0] for x in format.iter_unpack(self.raw) ] 173 174 def __repr__(self): 175 return f"[type:{self.type} len:{self._len}] {self.raw}" 176 177 178class NlAttrs: 179 def __init__(self, msg, offset=0): 180 self.attrs = [] 181 182 while offset < len(msg): 183 attr = NlAttr(msg, offset) 184 offset += attr.full_len 185 self.attrs.append(attr) 186 187 def __iter__(self): 188 yield from self.attrs 189 190 def __repr__(self): 191 msg = '' 192 for a in self.attrs: 193 if msg: 194 msg += '\n' 195 msg += repr(a) 196 return msg 197 198 199class NlMsg: 200 def __init__(self, msg, offset, attr_space=None): 201 self.hdr = msg[offset : offset + 16] 202 203 self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \ 204 struct.unpack("IHHII", self.hdr) 205 206 self.raw = msg[offset + 16 : offset + self.nl_len] 207 208 self.error = 0 209 self.done = 0 210 211 extack_off = None 212 if self.nl_type == Netlink.NLMSG_ERROR: 213 self.error = struct.unpack("i", self.raw[0:4])[0] 214 self.done = 1 215 extack_off = 20 216 elif self.nl_type == Netlink.NLMSG_DONE: 217 self.error = struct.unpack("i", self.raw[0:4])[0] 218 self.done = 1 219 extack_off = 4 220 221 self.extack = None 222 if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off: 223 self.extack = dict() 224 extack_attrs = NlAttrs(self.raw[extack_off:]) 225 for extack in extack_attrs: 226 if extack.type == Netlink.NLMSGERR_ATTR_MSG: 227 self.extack['msg'] = extack.as_strz() 228 elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE: 229 self.extack['miss-type'] = extack.as_scalar('u32') 230 elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST: 231 self.extack['miss-nest'] = extack.as_scalar('u32') 232 elif extack.type == Netlink.NLMSGERR_ATTR_OFFS: 233 self.extack['bad-attr-offs'] = extack.as_scalar('u32') 234 elif extack.type == Netlink.NLMSGERR_ATTR_POLICY: 235 self.extack['policy'] = self._decode_policy(extack.raw) 236 else: 237 if 'unknown' not in self.extack: 238 self.extack['unknown'] = [] 239 self.extack['unknown'].append(extack) 240 241 if attr_space: 242 self.annotate_extack(attr_space) 243 244 def _decode_policy(self, raw): 245 policy = {} 246 for attr in NlAttrs(raw): 247 if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE: 248 type = attr.as_scalar('u32') 249 policy['type'] = Netlink.AttrType(type).name 250 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S: 251 policy['min-value'] = attr.as_scalar('s64') 252 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S: 253 policy['max-value'] = attr.as_scalar('s64') 254 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U: 255 policy['min-value'] = attr.as_scalar('u64') 256 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U: 257 policy['max-value'] = attr.as_scalar('u64') 258 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH: 259 policy['min-length'] = attr.as_scalar('u32') 260 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH: 261 policy['max-length'] = attr.as_scalar('u32') 262 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK: 263 policy['bitfield32-mask'] = attr.as_scalar('u32') 264 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK: 265 policy['mask'] = attr.as_scalar('u64') 266 return policy 267 268 def annotate_extack(self, attr_space): 269 """ Make extack more human friendly with attribute information """ 270 271 # We don't have the ability to parse nests yet, so only do global 272 if 'miss-type' in self.extack and 'miss-nest' not in self.extack: 273 miss_type = self.extack['miss-type'] 274 if miss_type in attr_space.attrs_by_val: 275 spec = attr_space.attrs_by_val[miss_type] 276 self.extack['miss-type'] = spec['name'] 277 if 'doc' in spec: 278 self.extack['miss-type-doc'] = spec['doc'] 279 280 def cmd(self): 281 return self.nl_type 282 283 def __repr__(self): 284 msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}" 285 if self.error: 286 msg += '\n\terror: ' + str(self.error) 287 if self.extack: 288 msg += '\n\textack: ' + repr(self.extack) 289 return msg 290 291 292class NlMsgs: 293 def __init__(self, data): 294 self.msgs = [] 295 296 offset = 0 297 while offset < len(data): 298 msg = NlMsg(data, offset) 299 offset += msg.nl_len 300 self.msgs.append(msg) 301 302 def __iter__(self): 303 yield from self.msgs 304 305 306genl_family_name_to_id = None 307 308 309def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None): 310 # we prepend length in _genl_msg_finalize() 311 if seq is None: 312 seq = random.randint(1, 1024) 313 nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0) 314 genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0) 315 return nlmsg + genlmsg 316 317 318def _genl_msg_finalize(msg): 319 return struct.pack("I", len(msg) + 4) + msg 320 321 322def _genl_load_families(): 323 with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock: 324 sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) 325 326 msg = _genl_msg(Netlink.GENL_ID_CTRL, 327 Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP, 328 Netlink.CTRL_CMD_GETFAMILY, 1) 329 msg = _genl_msg_finalize(msg) 330 331 sock.send(msg, 0) 332 333 global genl_family_name_to_id 334 genl_family_name_to_id = dict() 335 336 while True: 337 reply = sock.recv(128 * 1024) 338 nms = NlMsgs(reply) 339 for nl_msg in nms: 340 if nl_msg.error: 341 print("Netlink error:", nl_msg.error) 342 return 343 if nl_msg.done: 344 return 345 346 gm = GenlMsg(nl_msg) 347 fam = dict() 348 for attr in NlAttrs(gm.raw): 349 if attr.type == Netlink.CTRL_ATTR_FAMILY_ID: 350 fam['id'] = attr.as_scalar('u16') 351 elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME: 352 fam['name'] = attr.as_strz() 353 elif attr.type == Netlink.CTRL_ATTR_MAXATTR: 354 fam['maxattr'] = attr.as_scalar('u32') 355 elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS: 356 fam['mcast'] = dict() 357 for entry in NlAttrs(attr.raw): 358 mcast_name = None 359 mcast_id = None 360 for entry_attr in NlAttrs(entry.raw): 361 if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME: 362 mcast_name = entry_attr.as_strz() 363 elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID: 364 mcast_id = entry_attr.as_scalar('u32') 365 if mcast_name and mcast_id is not None: 366 fam['mcast'][mcast_name] = mcast_id 367 if 'name' in fam and 'id' in fam: 368 genl_family_name_to_id[fam['name']] = fam 369 370 371class GenlMsg: 372 def __init__(self, nl_msg): 373 self.nl = nl_msg 374 self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0) 375 self.raw = nl_msg.raw[4:] 376 377 def cmd(self): 378 return self.genl_cmd 379 380 def __repr__(self): 381 msg = repr(self.nl) 382 msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n" 383 for a in self.raw_attrs: 384 msg += '\t\t' + repr(a) + '\n' 385 return msg 386 387 388class NetlinkProtocol: 389 def __init__(self, family_name, proto_num): 390 self.family_name = family_name 391 self.proto_num = proto_num 392 393 def _message(self, nl_type, nl_flags, seq=None): 394 if seq is None: 395 seq = random.randint(1, 1024) 396 nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0) 397 return nlmsg 398 399 def message(self, flags, command, version, seq=None): 400 return self._message(command, flags, seq) 401 402 def _decode(self, nl_msg): 403 return nl_msg 404 405 def decode(self, ynl, nl_msg, op): 406 msg = self._decode(nl_msg) 407 if op is None: 408 op = ynl.rsp_by_value[msg.cmd()] 409 fixed_header_size = ynl._struct_size(op.fixed_header) 410 msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size) 411 return msg 412 413 def get_mcast_id(self, mcast_name, mcast_groups): 414 if mcast_name not in mcast_groups: 415 raise Exception(f'Multicast group "{mcast_name}" not present in the spec') 416 return mcast_groups[mcast_name].value 417 418 def msghdr_size(self): 419 return 16 420 421 422class GenlProtocol(NetlinkProtocol): 423 def __init__(self, family_name): 424 super().__init__(family_name, Netlink.NETLINK_GENERIC) 425 426 global genl_family_name_to_id 427 if genl_family_name_to_id is None: 428 _genl_load_families() 429 430 self.genl_family = genl_family_name_to_id[family_name] 431 self.family_id = genl_family_name_to_id[family_name]['id'] 432 433 def message(self, flags, command, version, seq=None): 434 nlmsg = self._message(self.family_id, flags, seq) 435 genlmsg = struct.pack("BBH", command, version, 0) 436 return nlmsg + genlmsg 437 438 def _decode(self, nl_msg): 439 return GenlMsg(nl_msg) 440 441 def get_mcast_id(self, mcast_name, mcast_groups): 442 if mcast_name not in self.genl_family['mcast']: 443 raise Exception(f'Multicast group "{mcast_name}" not present in the family') 444 return self.genl_family['mcast'][mcast_name] 445 446 def msghdr_size(self): 447 return super().msghdr_size() + 4 448 449 450class SpaceAttrs: 451 SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values']) 452 453 def __init__(self, attr_space, attrs, outer = None): 454 outer_scopes = outer.scopes if outer else [] 455 inner_scope = self.SpecValuesPair(attr_space, attrs) 456 self.scopes = [inner_scope] + outer_scopes 457 458 def lookup(self, name): 459 for scope in self.scopes: 460 if name in scope.spec: 461 if name in scope.values: 462 return scope.values[name] 463 spec_name = scope.spec.yaml['name'] 464 raise Exception( 465 f"No value for '{name}' in attribute space '{spec_name}'") 466 raise Exception(f"Attribute '{name}' not defined in any attribute-set") 467 468 469# 470# YNL implementation details. 471# 472 473 474class YnlFamily(SpecFamily): 475 def __init__(self, def_path, schema=None, process_unknown=False, 476 recv_size=0): 477 super().__init__(def_path, schema) 478 479 self.include_raw = False 480 self.process_unknown = process_unknown 481 482 try: 483 if self.proto == "netlink-raw": 484 self.nlproto = NetlinkProtocol(self.yaml['name'], 485 self.yaml['protonum']) 486 else: 487 self.nlproto = GenlProtocol(self.yaml['name']) 488 except KeyError: 489 raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel") 490 491 self._recv_dbg = False 492 # Note that netlink will use conservative (min) message size for 493 # the first dump recv() on the socket, our setting will only matter 494 # from the second recv() on. 495 self._recv_size = recv_size if recv_size else 131072 496 # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo) 497 # for a message, so smaller receive sizes will lead to truncation. 498 # Note that the min size for other families may be larger than 4k! 499 if self._recv_size < 4000: 500 raise ConfigError() 501 502 self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num) 503 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) 504 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1) 505 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1) 506 507 self.async_msg_ids = set() 508 self.async_msg_queue = queue.Queue() 509 510 for msg in self.msgs.values(): 511 if msg.is_async: 512 self.async_msg_ids.add(msg.rsp_value) 513 514 for op_name, op in self.ops.items(): 515 bound_f = functools.partial(self._op, op_name) 516 setattr(self, op.ident_name, bound_f) 517 518 519 def ntf_subscribe(self, mcast_name): 520 mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups) 521 self.sock.bind((0, 0)) 522 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP, 523 mcast_id) 524 525 def set_recv_dbg(self, enabled): 526 self._recv_dbg = enabled 527 528 def _recv_dbg_print(self, reply, nl_msgs): 529 if not self._recv_dbg: 530 return 531 print("Recv: read", len(reply), "bytes,", 532 len(nl_msgs.msgs), "messages", file=sys.stderr) 533 for nl_msg in nl_msgs: 534 print(" ", nl_msg, file=sys.stderr) 535 536 def _encode_enum(self, attr_spec, value): 537 enum = self.consts[attr_spec['enum']] 538 if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): 539 scalar = 0 540 if isinstance(value, str): 541 value = [value] 542 for single_value in value: 543 scalar += enum.entries[single_value].user_value(as_flags = True) 544 return scalar 545 else: 546 return enum.entries[value].user_value() 547 548 def _get_scalar(self, attr_spec, value): 549 try: 550 return int(value) 551 except (ValueError, TypeError) as e: 552 if 'enum' in attr_spec: 553 return self._encode_enum(attr_spec, value) 554 if attr_spec.display_hint: 555 return self._from_string(value, attr_spec) 556 raise e 557 558 def _add_attr(self, space, name, value, search_attrs): 559 try: 560 attr = self.attr_sets[space][name] 561 except KeyError: 562 raise Exception(f"Space '{space}' has no attribute '{name}'") 563 nl_type = attr.value 564 565 if attr.is_multi and isinstance(value, list): 566 attr_payload = b'' 567 for subvalue in value: 568 attr_payload += self._add_attr(space, name, subvalue, search_attrs) 569 return attr_payload 570 571 if attr["type"] == 'nest': 572 nl_type |= Netlink.NLA_F_NESTED 573 sub_space = attr['nested-attributes'] 574 attr_payload = self._add_nest_attrs(value, sub_space, search_attrs) 575 elif attr['type'] == 'indexed-array' and attr['sub-type'] == 'nest': 576 nl_type |= Netlink.NLA_F_NESTED 577 sub_space = attr['nested-attributes'] 578 attr_payload = self._encode_indexed_array(value, sub_space, 579 search_attrs) 580 elif attr["type"] == 'flag': 581 if not value: 582 # If value is absent or false then skip attribute creation. 583 return b'' 584 attr_payload = b'' 585 elif attr["type"] == 'string': 586 attr_payload = str(value).encode('ascii') + b'\x00' 587 elif attr["type"] == 'binary': 588 if value is None: 589 attr_payload = b'' 590 elif isinstance(value, bytes): 591 attr_payload = value 592 elif isinstance(value, str): 593 if attr.display_hint: 594 attr_payload = self._from_string(value, attr) 595 else: 596 attr_payload = bytes.fromhex(value) 597 elif isinstance(value, dict) and attr.struct_name: 598 attr_payload = self._encode_struct(attr.struct_name, value) 599 elif isinstance(value, list) and attr.sub_type in NlAttr.type_formats: 600 format = NlAttr.get_format(attr.sub_type) 601 attr_payload = b''.join([format.pack(x) for x in value]) 602 else: 603 raise Exception(f'Unknown type for binary attribute, value: {value}') 604 elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar: 605 scalar = self._get_scalar(attr, value) 606 if attr.is_auto_scalar: 607 attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64') 608 else: 609 attr_type = attr["type"] 610 format = NlAttr.get_format(attr_type, attr.byte_order) 611 attr_payload = format.pack(scalar) 612 elif attr['type'] in "bitfield32": 613 scalar_value = self._get_scalar(attr, value["value"]) 614 scalar_selector = self._get_scalar(attr, value["selector"]) 615 attr_payload = struct.pack("II", scalar_value, scalar_selector) 616 elif attr['type'] == 'sub-message': 617 msg_format, _ = self._resolve_selector(attr, search_attrs) 618 attr_payload = b'' 619 if msg_format.fixed_header: 620 attr_payload += self._encode_struct(msg_format.fixed_header, value) 621 if msg_format.attr_set: 622 if msg_format.attr_set in self.attr_sets: 623 nl_type |= Netlink.NLA_F_NESTED 624 sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs) 625 for subname, subvalue in value.items(): 626 attr_payload += self._add_attr(msg_format.attr_set, 627 subname, subvalue, sub_attrs) 628 else: 629 raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'") 630 else: 631 raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}') 632 633 return self._add_attr_raw(nl_type, attr_payload) 634 635 def _add_attr_raw(self, nl_type, attr_payload): 636 pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4) 637 return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad 638 639 def _add_nest_attrs(self, value, sub_space, search_attrs): 640 sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs) 641 attr_payload = b'' 642 for subname, subvalue in value.items(): 643 attr_payload += self._add_attr(sub_space, subname, subvalue, 644 sub_attrs) 645 return attr_payload 646 647 def _encode_indexed_array(self, vals, sub_space, search_attrs): 648 attr_payload = b'' 649 for i, val in enumerate(vals): 650 idx = i | Netlink.NLA_F_NESTED 651 val_payload = self._add_nest_attrs(val, sub_space, search_attrs) 652 attr_payload += self._add_attr_raw(idx, val_payload) 653 return attr_payload 654 655 def _get_enum_or_unknown(self, enum, raw): 656 try: 657 name = enum.entries_by_val[raw].name 658 except KeyError as error: 659 if self.process_unknown: 660 name = f"Unknown({raw})" 661 else: 662 raise error 663 return name 664 665 def _decode_enum(self, raw, attr_spec): 666 enum = self.consts[attr_spec['enum']] 667 if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): 668 i = 0 669 value = set() 670 while raw: 671 if raw & 1: 672 value.add(self._get_enum_or_unknown(enum, i)) 673 raw >>= 1 674 i += 1 675 else: 676 value = self._get_enum_or_unknown(enum, raw) 677 return value 678 679 def _decode_binary(self, attr, attr_spec): 680 if attr_spec.struct_name: 681 decoded = self._decode_struct(attr.raw, attr_spec.struct_name) 682 elif attr_spec.sub_type: 683 decoded = attr.as_c_array(attr_spec.sub_type) 684 if 'enum' in attr_spec: 685 decoded = [ self._decode_enum(x, attr_spec) for x in decoded ] 686 elif attr_spec.display_hint: 687 decoded = [ self._formatted_string(x, attr_spec.display_hint) 688 for x in decoded ] 689 else: 690 decoded = attr.as_bin() 691 if attr_spec.display_hint: 692 decoded = self._formatted_string(decoded, attr_spec.display_hint) 693 return decoded 694 695 def _decode_array_attr(self, attr, attr_spec): 696 decoded = [] 697 offset = 0 698 while offset < len(attr.raw): 699 item = NlAttr(attr.raw, offset) 700 offset += item.full_len 701 702 if attr_spec["sub-type"] == 'nest': 703 subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes']) 704 decoded.append({ item.type: subattrs }) 705 elif attr_spec["sub-type"] == 'binary': 706 subattr = item.as_bin() 707 if attr_spec.display_hint: 708 subattr = self._formatted_string(subattr, attr_spec.display_hint) 709 decoded.append(subattr) 710 elif attr_spec["sub-type"] in NlAttr.type_formats: 711 subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order) 712 if 'enum' in attr_spec: 713 subattr = self._decode_enum(subattr, attr_spec) 714 elif attr_spec.display_hint: 715 subattr = self._formatted_string(subattr, attr_spec.display_hint) 716 decoded.append(subattr) 717 else: 718 raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}') 719 return decoded 720 721 def _decode_nest_type_value(self, attr, attr_spec): 722 decoded = {} 723 value = attr 724 for name in attr_spec['type-value']: 725 value = NlAttr(value.raw, 0) 726 decoded[name] = value.type 727 subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes']) 728 decoded.update(subattrs) 729 return decoded 730 731 def _decode_unknown(self, attr): 732 if attr.is_nest: 733 return self._decode(NlAttrs(attr.raw), None) 734 else: 735 return attr.as_bin() 736 737 def _rsp_add(self, rsp, name, is_multi, decoded): 738 if is_multi is None: 739 if name in rsp and type(rsp[name]) is not list: 740 rsp[name] = [rsp[name]] 741 is_multi = True 742 else: 743 is_multi = False 744 745 if not is_multi: 746 rsp[name] = decoded 747 elif name in rsp: 748 rsp[name].append(decoded) 749 else: 750 rsp[name] = [decoded] 751 752 def _resolve_selector(self, attr_spec, search_attrs): 753 sub_msg = attr_spec.sub_message 754 if sub_msg not in self.sub_msgs: 755 raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}") 756 sub_msg_spec = self.sub_msgs[sub_msg] 757 758 selector = attr_spec.selector 759 value = search_attrs.lookup(selector) 760 if value not in sub_msg_spec.formats: 761 raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'") 762 763 spec = sub_msg_spec.formats[value] 764 return spec, value 765 766 def _decode_sub_msg(self, attr, attr_spec, search_attrs): 767 msg_format, _ = self._resolve_selector(attr_spec, search_attrs) 768 decoded = {} 769 offset = 0 770 if msg_format.fixed_header: 771 decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header)) 772 offset = self._struct_size(msg_format.fixed_header) 773 if msg_format.attr_set: 774 if msg_format.attr_set in self.attr_sets: 775 subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set) 776 decoded.update(subdict) 777 else: 778 raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'") 779 return decoded 780 781 def _decode(self, attrs, space, outer_attrs = None): 782 rsp = dict() 783 if space: 784 attr_space = self.attr_sets[space] 785 search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs) 786 787 for attr in attrs: 788 try: 789 attr_spec = attr_space.attrs_by_val[attr.type] 790 except (KeyError, UnboundLocalError): 791 if not self.process_unknown: 792 raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'") 793 attr_name = f"UnknownAttr({attr.type})" 794 self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr)) 795 continue 796 797 try: 798 if attr_spec["type"] == 'nest': 799 subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs) 800 decoded = subdict 801 elif attr_spec["type"] == 'string': 802 decoded = attr.as_strz() 803 elif attr_spec["type"] == 'binary': 804 decoded = self._decode_binary(attr, attr_spec) 805 elif attr_spec["type"] == 'flag': 806 decoded = True 807 elif attr_spec.is_auto_scalar: 808 decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order) 809 if 'enum' in attr_spec: 810 decoded = self._decode_enum(decoded, attr_spec) 811 elif attr_spec["type"] in NlAttr.type_formats: 812 decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order) 813 if 'enum' in attr_spec: 814 decoded = self._decode_enum(decoded, attr_spec) 815 elif attr_spec.display_hint: 816 decoded = self._formatted_string(decoded, attr_spec.display_hint) 817 elif attr_spec["type"] == 'indexed-array': 818 decoded = self._decode_array_attr(attr, attr_spec) 819 elif attr_spec["type"] == 'bitfield32': 820 value, selector = struct.unpack("II", attr.raw) 821 if 'enum' in attr_spec: 822 value = self._decode_enum(value, attr_spec) 823 selector = self._decode_enum(selector, attr_spec) 824 decoded = {"value": value, "selector": selector} 825 elif attr_spec["type"] == 'sub-message': 826 decoded = self._decode_sub_msg(attr, attr_spec, search_attrs) 827 elif attr_spec["type"] == 'nest-type-value': 828 decoded = self._decode_nest_type_value(attr, attr_spec) 829 else: 830 if not self.process_unknown: 831 raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}') 832 decoded = self._decode_unknown(attr) 833 834 self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded) 835 except: 836 print(f"Error decoding '{attr_spec.name}' from '{space}'") 837 raise 838 839 return rsp 840 841 def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs): 842 for attr in attrs: 843 try: 844 attr_spec = attr_set.attrs_by_val[attr.type] 845 except KeyError: 846 raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'") 847 if offset > target: 848 break 849 if offset == target: 850 return '.' + attr_spec.name 851 852 if offset + attr.full_len <= target: 853 offset += attr.full_len 854 continue 855 856 pathname = attr_spec.name 857 if attr_spec['type'] == 'nest': 858 sub_attrs = self.attr_sets[attr_spec['nested-attributes']] 859 search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name'])) 860 elif attr_spec['type'] == 'sub-message': 861 msg_format, value = self._resolve_selector(attr_spec, search_attrs) 862 if msg_format is None: 863 raise Exception(f"Can't resolve sub-message of {attr_spec['name']} for extack") 864 sub_attrs = self.attr_sets[msg_format.attr_set] 865 pathname += f"({value})" 866 else: 867 raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack") 868 offset += 4 869 subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs, 870 offset, target, search_attrs) 871 if subpath is None: 872 return None 873 return '.' + pathname + subpath 874 875 return None 876 877 def _decode_extack(self, request, op, extack, vals): 878 if 'bad-attr-offs' not in extack: 879 return 880 881 msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op) 882 offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header) 883 search_attrs = SpaceAttrs(op.attr_set, vals) 884 path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset, 885 extack['bad-attr-offs'], search_attrs) 886 if path: 887 del extack['bad-attr-offs'] 888 extack['bad-attr'] = path 889 890 def _struct_size(self, name): 891 if name: 892 members = self.consts[name].members 893 size = 0 894 for m in members: 895 if m.type in ['pad', 'binary']: 896 if m.struct: 897 size += self._struct_size(m.struct) 898 else: 899 size += m.len 900 else: 901 format = NlAttr.get_format(m.type, m.byte_order) 902 size += format.size 903 return size 904 else: 905 return 0 906 907 def _decode_struct(self, data, name): 908 members = self.consts[name].members 909 attrs = dict() 910 offset = 0 911 for m in members: 912 value = None 913 if m.type == 'pad': 914 offset += m.len 915 elif m.type == 'binary': 916 if m.struct: 917 len = self._struct_size(m.struct) 918 value = self._decode_struct(data[offset : offset + len], 919 m.struct) 920 offset += len 921 else: 922 value = data[offset : offset + m.len] 923 offset += m.len 924 else: 925 format = NlAttr.get_format(m.type, m.byte_order) 926 [ value ] = format.unpack_from(data, offset) 927 offset += format.size 928 if value is not None: 929 if m.enum: 930 value = self._decode_enum(value, m) 931 elif m.display_hint: 932 value = self._formatted_string(value, m.display_hint) 933 attrs[m.name] = value 934 return attrs 935 936 def _encode_struct(self, name, vals): 937 members = self.consts[name].members 938 attr_payload = b'' 939 for m in members: 940 value = vals.pop(m.name) if m.name in vals else None 941 if m.type == 'pad': 942 attr_payload += bytearray(m.len) 943 elif m.type == 'binary': 944 if m.struct: 945 if value is None: 946 value = dict() 947 attr_payload += self._encode_struct(m.struct, value) 948 else: 949 if value is None: 950 attr_payload += bytearray(m.len) 951 else: 952 attr_payload += bytes.fromhex(value) 953 else: 954 if value is None: 955 value = 0 956 format = NlAttr.get_format(m.type, m.byte_order) 957 attr_payload += format.pack(value) 958 return attr_payload 959 960 def _formatted_string(self, raw, display_hint): 961 if display_hint == 'mac': 962 formatted = ':'.join('%02x' % b for b in raw) 963 elif display_hint == 'hex': 964 if isinstance(raw, int): 965 formatted = hex(raw) 966 else: 967 formatted = bytes.hex(raw, ' ') 968 elif display_hint in [ 'ipv4', 'ipv6', 'ipv4-or-v6' ]: 969 formatted = format(ipaddress.ip_address(raw)) 970 elif display_hint == 'uuid': 971 formatted = str(uuid.UUID(bytes=raw)) 972 else: 973 formatted = raw 974 return formatted 975 976 def _from_string(self, string, attr_spec): 977 if attr_spec.display_hint in ['ipv4', 'ipv6', 'ipv4-or-v6']: 978 ip = ipaddress.ip_address(string) 979 if attr_spec['type'] == 'binary': 980 raw = ip.packed 981 else: 982 raw = int(ip) 983 elif attr_spec.display_hint == 'hex': 984 if attr_spec['type'] == 'binary': 985 raw = bytes.fromhex(string) 986 else: 987 raw = int(string, 16) 988 elif attr_spec.display_hint == 'mac': 989 # Parse MAC address in format "00:11:22:33:44:55" or "001122334455" 990 if ':' in string: 991 mac_bytes = [int(x, 16) for x in string.split(':')] 992 else: 993 if len(string) % 2 != 0: 994 raise Exception(f"Invalid MAC address format: {string}") 995 mac_bytes = [int(string[i:i+2], 16) for i in range(0, len(string), 2)] 996 raw = bytes(mac_bytes) 997 else: 998 raise Exception(f"Display hint '{attr_spec.display_hint}' not implemented" 999 f" when parsing '{attr_spec['name']}'") 1000 return raw 1001 1002 def handle_ntf(self, decoded): 1003 msg = dict() 1004 if self.include_raw: 1005 msg['raw'] = decoded 1006 op = self.rsp_by_value[decoded.cmd()] 1007 attrs = self._decode(decoded.raw_attrs, op.attr_set.name) 1008 if op.fixed_header: 1009 attrs.update(self._decode_struct(decoded.raw, op.fixed_header)) 1010 1011 msg['name'] = op['name'] 1012 msg['msg'] = attrs 1013 self.async_msg_queue.put(msg) 1014 1015 def check_ntf(self): 1016 while True: 1017 try: 1018 reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT) 1019 except BlockingIOError: 1020 return 1021 1022 nms = NlMsgs(reply) 1023 self._recv_dbg_print(reply, nms) 1024 for nl_msg in nms: 1025 if nl_msg.error: 1026 print("Netlink error in ntf!?", os.strerror(-nl_msg.error)) 1027 print(nl_msg) 1028 continue 1029 if nl_msg.done: 1030 print("Netlink done while checking for ntf!?") 1031 continue 1032 1033 decoded = self.nlproto.decode(self, nl_msg, None) 1034 if decoded.cmd() not in self.async_msg_ids: 1035 print("Unexpected msg id while checking for ntf", decoded) 1036 continue 1037 1038 self.handle_ntf(decoded) 1039 1040 def poll_ntf(self, duration=None): 1041 start_time = time.time() 1042 selector = selectors.DefaultSelector() 1043 selector.register(self.sock, selectors.EVENT_READ) 1044 1045 while True: 1046 try: 1047 yield self.async_msg_queue.get_nowait() 1048 except queue.Empty: 1049 if duration is not None: 1050 timeout = start_time + duration - time.time() 1051 if timeout <= 0: 1052 return 1053 else: 1054 timeout = None 1055 events = selector.select(timeout) 1056 if events: 1057 self.check_ntf() 1058 1059 def operation_do_attributes(self, name): 1060 """ 1061 For a given operation name, find and return a supported 1062 set of attributes (as a dict). 1063 """ 1064 op = self.find_operation(name) 1065 if not op: 1066 return None 1067 1068 return op['do']['request']['attributes'].copy() 1069 1070 def _encode_message(self, op, vals, flags, req_seq): 1071 nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK 1072 for flag in flags or []: 1073 nl_flags |= flag 1074 1075 msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq) 1076 if op.fixed_header: 1077 msg += self._encode_struct(op.fixed_header, vals) 1078 search_attrs = SpaceAttrs(op.attr_set, vals) 1079 for name, value in vals.items(): 1080 msg += self._add_attr(op.attr_set.name, name, value, search_attrs) 1081 msg = _genl_msg_finalize(msg) 1082 return msg 1083 1084 def _ops(self, ops): 1085 reqs_by_seq = {} 1086 req_seq = random.randint(1024, 65535) 1087 payload = b'' 1088 for (method, vals, flags) in ops: 1089 op = self.ops[method] 1090 msg = self._encode_message(op, vals, flags, req_seq) 1091 reqs_by_seq[req_seq] = (op, vals, msg, flags) 1092 payload += msg 1093 req_seq += 1 1094 1095 self.sock.send(payload, 0) 1096 1097 done = False 1098 rsp = [] 1099 op_rsp = [] 1100 while not done: 1101 reply = self.sock.recv(self._recv_size) 1102 nms = NlMsgs(reply) 1103 self._recv_dbg_print(reply, nms) 1104 for nl_msg in nms: 1105 if nl_msg.nl_seq in reqs_by_seq: 1106 (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq] 1107 if nl_msg.extack: 1108 nl_msg.annotate_extack(op.attr_set) 1109 self._decode_extack(req_msg, op, nl_msg.extack, vals) 1110 else: 1111 op = None 1112 req_flags = [] 1113 1114 if nl_msg.error: 1115 raise NlError(nl_msg) 1116 if nl_msg.done: 1117 if nl_msg.extack: 1118 print("Netlink warning:") 1119 print(nl_msg) 1120 1121 if Netlink.NLM_F_DUMP in req_flags: 1122 rsp.append(op_rsp) 1123 elif not op_rsp: 1124 rsp.append(None) 1125 elif len(op_rsp) == 1: 1126 rsp.append(op_rsp[0]) 1127 else: 1128 rsp.append(op_rsp) 1129 op_rsp = [] 1130 1131 del reqs_by_seq[nl_msg.nl_seq] 1132 done = len(reqs_by_seq) == 0 1133 break 1134 1135 decoded = self.nlproto.decode(self, nl_msg, op) 1136 1137 # Check if this is a reply to our request 1138 if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value: 1139 if decoded.cmd() in self.async_msg_ids: 1140 self.handle_ntf(decoded) 1141 continue 1142 else: 1143 print('Unexpected message: ' + repr(decoded)) 1144 continue 1145 1146 rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name) 1147 if op.fixed_header: 1148 rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header)) 1149 op_rsp.append(rsp_msg) 1150 1151 return rsp 1152 1153 def _op(self, method, vals, flags=None, dump=False): 1154 req_flags = flags or [] 1155 if dump: 1156 req_flags.append(Netlink.NLM_F_DUMP) 1157 1158 ops = [(method, vals, req_flags)] 1159 return self._ops(ops)[0] 1160 1161 def do(self, method, vals, flags=None): 1162 return self._op(method, vals, flags) 1163 1164 def dump(self, method, vals): 1165 return self._op(method, vals, dump=True) 1166 1167 def do_multi(self, ops): 1168 return self._ops(ops) 1169