# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
from enum import Enum
import functools
import os
import random
import socket
import struct
from struct import Struct
import sys
import yaml
import ipaddress
import uuid
import queue
import selectors
import time

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Namespace holding the numeric netlink / genetlink constants used below."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # Note: the flag groups below deliberately reuse the same bit values.
    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    # Bits of the attribute type field which are flags, not part of the type
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Functional Enum API numbers the members starting from 1 ('flag' == 1)
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                 's8', 's16', 's32', 's64',
                                 'binary', 'string', 'nul-string',
                                 'nested', 'nested-array',
                                 'bitfield32', 'sint', 'uint'])


class NlError(Exception):
    """Error reported in a netlink message.

    Attributes:
        nl_msg: the message object which carried the error.
        error: the error code with its sign flipped (``-nl_msg.error``),
            which is what os.strerror() is fed in __str__().
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"


class ConfigError(Exception):
    """Raised when the library is given an unusable configuration value."""


class NlAttr:
    """A single netlink attribute parsed out of a bytes buffer.

    The header is read as two native-endian 16-bit fields (length, type),
    the payload is whatever follows the 4 byte header up to 'length'.
    """

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    # Pre-compiled struct formats per scalar type, one for each byte order
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute which starts at @offset inside @raw.

        Raises:
            Exception: if the length field is smaller than the 4 byte header.
            struct.error: if fewer than 4 bytes are available at @offset.
        """
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # A length below the header size can never be valid, and a length
        # of 0 would make full_len 0 so walkers would never make progress.
        if self._len < 4:
            raise Exception(f"Invalid netlink attribute length {self._len} at offset {offset}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        # Despite the name this is the length from the header, i.e. it
        # includes the 4 header bytes (see the slice computing self.raw).
        self.payload_len = self._len
        # Distance to the next attribute: length rounded up to 4 bytes
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for scalar @attr_type in the given byte order.

        @byte_order "big-endian" selects big endian, any other truthy value
        little endian, and a falsy value the native byte order.
        Raises KeyError if @attr_type is not in type_formats.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" \
                   else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as a single scalar of type @attr_type."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a scalar whose width (32 or 64 bit) follows the payload size.

        Only the first character of @attr_type is used, to pick between the
        signed ('s') and unsigned ('u') formats.
        """
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar len payload be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII, dropping the last (terminator) character."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes as-is."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of native-endian scalars of @type."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """Sequence of NlAttr objects parsed back to back from a buffer."""

    def __init__(self, msg, offset=0):
        """Parse attributes from @msg starting at @offset until the end of @msg."""
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """A single netlink message: the 16 byte header, payload, error and extack."""

    def __init__(self, msg, offset, attr_space=None):
        """Parse the message which starts at @offset inside @msg.

        @attr_space, if given, is used to translate the 'miss-type' extack
        value into an attribute name (and doc string).

        Raises:
            Exception: if the length in the header is below the header size.
        """
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)
        # nl_len includes the header; anything shorter is malformed and
        # a length of 0 would stall the NlMsgs walk forever.
        if self.nl_len < 16:
            raise Exception(f"Invalid netlink message length {self.nl_len} at offset {offset}")

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # Skip the error code and 16 more bytes - presumably the echoed
            # request header (sockets in this file set NETLINK_CAP_ACK).
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        self.extack['miss-type'] = spec['name']
                        if 'doc' in spec:
                            self.extack['miss-type-doc'] = spec['doc']

    def _decode_policy(self, raw):
        """Decode a nested policy description into a dict; unknown attrs are skipped."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_id = attr.as_scalar('u32')
                policy['type'] = Netlink.AttrType(type_id).name
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def cmd(self):
        """Return the command of the message, for raw netlink that's the message type."""
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages contained in one receive buffer."""

    def __init__(self, data, attr_space=None):
        """Parse @data into NlMsg objects, @attr_space is passed to each NlMsg."""
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            # Messages in a buffer start on 4 byte boundaries (NLMSG_ALIGN)
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Cache of genetlink families by name, filled in by _genl_load_families()
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build a netlink + genetlink header, without the leading length field.

    If @seq is None a random sequence number between 1 and 1024 is used.
    """
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the 32-bit total length (which counts the length field itself)."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump the genetlink families from the kernel into genl_family_name_to_id.

    Each entry is a dict with 'id' and 'name', plus 'maxattr' and 'mcast'
    (multicast group name -> id) when the kernel reported them.
    On a netlink error the code is printed and loading stops, entries
    collected up to that point are kept.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in NlAttrs(gm.raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        # Each entry is itself a nest holding name and id
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """Generic netlink view of an NlMsg: command, version and the rest of the payload."""

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        # Payload past the 4 byte genetlink header
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the genetlink command of the message."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs only exists once NetlinkProtocol.decode() processed
        # the message, don't blow up when printing before that.
        for a in getattr(self, 'raw_attrs', []):
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Message framing for a raw (non-generic) netlink protocol."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack the netlink header minus the length, see _genl_msg_finalize()."""
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        """Build the request header, the command is the message type, @version is unused."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        """Return the protocol-specific view of @nl_msg, here the message itself."""
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Attach the parsed attributes to the message as 'raw_attrs'.

        If @op is None the operation is looked up in ynl.rsp_by_value by
        command. Attributes are parsed past the fixed header of the op.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the id of multicast group @mcast_name as defined by the spec."""
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value

    def msghdr_size(self):
        """Return the size of the message header, in bytes."""
        return 16


class GenlProtocol(NetlinkProtocol):
    """Message framing for generic netlink families."""

    def __init__(self, family_name):
        """Resolve @family_name to its id, loading the family list on first use.

        Raises KeyError if @family_name is not among the loaded families.
        """
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']

    def message(self, flags, command, version, seq=None):
        """Build the request header: netlink header (type is the family id) + genetlink header."""
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the id of @mcast_name as reported for the family, @mcast_groups is unused."""
        if mcast_name not in self.genl_family['mcast']:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return self.genl_family['mcast'][mcast_name]

    def msghdr_size(self):
        """Return the size of the netlink header plus the 4 byte genetlink header."""
        return super().msghdr_size() + 4


class SpaceAttrs:
    """Chain of (attribute space, values) scopes, innermost scope first."""

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        """Create a scope for @attr_space / @attrs on top of the scopes of @outer."""
        outer_scopes = outer.scopes if outer else []
        inner_scope = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner_scope] + outer_scopes

    def lookup(self, name):
        """Return the value of @name from the innermost scope which defines it.

        Raises:
            Exception: if the first scope defining @name holds no value for
                it, or if no scope defines @name at all.
        """
        for scope in self.scopes:
            if name in scope.spec:
                if name in scope.values:
                    return scope.values[name]
                spec_name = scope.spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """Netlink family driven by a YAML spec: encodes requests, decodes replies.

    Every operation of the spec is also exposed as a method named after
    the operation's ident_name, taking the same arguments as _op() minus
    the method name.
    """

    def __init__(self, def_path, schema=None, process_unknown=False,
                 recv_size=0):
        """Load the spec and open the netlink socket.

        Args:
            def_path: spec location, passed to SpecFamily.
            schema: schema, passed to SpecFamily.
            process_unknown: decode attributes missing from the spec as
                "UnknownAttr(N)" instead of raising.
            recv_size: size for recv() calls, 0 selects 131072.

        Raises:
            Exception: if protocol setup hits a KeyError (reported as the
                family not being supported by the kernel).
            ConfigError: if the receive size is below 4000.
        """
        super().__init__(def_path, schema)

        self.include_raw = False
        self.process_unknown = process_unknown

        try:
            if self.proto == "netlink-raw":
                self.nlproto = NetlinkProtocol(self.yaml['name'],
                                               self.yaml['protonum'])
            else:
                self.nlproto = GenlProtocol(self.yaml['name'])
        except KeyError:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

        self._recv_dbg = False
        # Note that netlink will use conservative (min) message size for
        # the first dump recv() on the socket, our setting will only matter
        # from the second recv() on.
        self._recv_size = recv_size if recv_size else 131072
        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
        # for a message, so smaller receive sizes will lead to truncation.
        # Note that the min size for other families may be larger than 4k!
        if self._recv_size < 4000:
            raise ConfigError()

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

        # Response values of async messages, and the decoded notifications
        self.async_msg_ids = set()
        self.async_msg_queue = queue.Queue()

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)


    def ntf_subscribe(self, mcast_name):
        """Bind the socket and join the multicast group called @mcast_name."""
        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_id)

    def set_recv_dbg(self, enabled):
        """Enable / disable printing of received messages to stderr."""
        self._recv_dbg = enabled

    def _recv_dbg_print(self, reply, nl_msgs):
        """Print a summary of one recv() to stderr if debug is enabled."""
        if not self._recv_dbg:
            return
        print("Recv: read", len(reply), "bytes,",
              len(nl_msgs.msgs), "messages", file=sys.stderr)
        for nl_msg in nl_msgs:
            print("  ", nl_msg, file=sys.stderr)

    def _encode_enum(self, attr_spec, value):
        """Convert enum entry name(s) to the integer to put on the wire.

        For flags @value may be a single name or an iterable of names.
        """
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            scalar = 0
            if isinstance(value, str):
                value = [value]
            for single_value in value:
                # OR rather than add, so a flag listed twice stays correct
                scalar |= enum.entries[single_value].user_value(as_flags = True)
            return scalar
        else:
            return enum.entries[value].user_value()

    def _get_scalar(self, attr_spec, value):
        """Convert user input to an int: int(), then enum names, then display hint.

        Re-raises the int() conversion error if neither fallback applies.
        """
        try:
            return int(value)
        except (ValueError, TypeError):
            if 'enum' in attr_spec:
                return self._encode_enum(attr_spec, value)
            if attr_spec.display_hint:
                return self._from_string(value, attr_spec)
            raise

    def _add_attr(self, space, name, value, search_attrs):
        """Encode attribute @name from attribute set @space, return the bytes.

        Args:
            space: name of the attribute set @name belongs to.
            name: name of the attribute.
            value: value to encode, a list for multi-attrs results in
                one attribute per element.
            search_attrs: SpaceAttrs used to find sub-message selectors.

        Returns:
            Header + payload + padding to 4 bytes, b'' for a false flag.

        Raises:
            Exception: for unknown attributes, types or attribute sets.
        """
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value

        if attr.is_multi and isinstance(value, list):
            attr_payload = b''
            for subvalue in value:
                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
            return attr_payload

        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            sub_space = attr['nested-attributes']
            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
        elif attr["type"] == 'flag':
            if not value:
                # If value is absent or false then skip attribute creation.
                return b''
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            if isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                if attr.display_hint:
                    attr_payload = self._from_string(value, attr)
                else:
                    attr_payload = bytes.fromhex(value)
            elif isinstance(value, dict) and attr.struct_name:
                attr_payload = self._encode_struct(attr.struct_name, value)
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
            scalar = self._get_scalar(attr, value)
            if attr.is_auto_scalar:
                # Signed values only have 31 bits of magnitude in 32 bit
                # form, anything larger has to go out as 64 bit.
                max_bits = 31 if attr["type"][0] == 's' else 32
                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= max_bits else '64')
            else:
                attr_type = attr["type"]
            fmt = NlAttr.get_format(attr_type, attr.byte_order)
            attr_payload = fmt.pack(scalar)
        elif attr['type'] == "bitfield32":
            scalar_value = self._get_scalar(attr, value["value"])
            scalar_selector = self._get_scalar(attr, value["selector"])
            attr_payload = struct.pack("II", scalar_value, scalar_selector)
        elif attr['type'] == 'sub-message':
            msg_format = self._resolve_selector(attr, search_attrs)
            attr_payload = b''
            if msg_format.fixed_header:
                attr_payload += self._encode_struct(msg_format.fixed_header, value)
            if msg_format.attr_set:
                if msg_format.attr_set in self.attr_sets:
                    nl_type |= Netlink.NLA_F_NESTED
                    # SpaceAttrs needs the attribute set itself, not its
                    # name, same as in the 'nest' case above.
                    sub_attrs = SpaceAttrs(self.attr_sets[msg_format.attr_set],
                                           value, search_attrs)
                    for subname, subvalue in value.items():
                        attr_payload += self._add_attr(msg_format.attr_set,
                                                       subname, subvalue, sub_attrs)
                else:
                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Convert an integer to an enum entry name, or a set of names for flags."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            i = 0
            value = set()
            # Walk the bits from the lowest, i is the bit position
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attribute as a struct, array of scalars or plain bytes."""
        if attr_spec.struct_name:
            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
            if 'enum' in attr_spec:
                decoded = [ self._decode_enum(x, attr_spec) for x in decoded ]
            elif attr_spec.display_hint:
                decoded = [ self._formatted_string(x, attr_spec.display_hint)
                            for x in decoded ]
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = self._formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode_array_attr(self, attr, attr_spec):
        """Decode an indexed-array attribute into a list, one entry per item.

        Nest items are returned as { item type: decoded attributes }.
        """
        decoded = []
        for item in NlAttrs(attr.raw):
            if attr_spec["sub-type"] == 'nest':
                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
                decoded.append({ item.type: subattrs })
            elif attr_spec["sub-type"] == 'binary':
                subattr = item.as_bin()
                if attr_spec.display_hint:
                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
                decoded.append(subattr)
            elif attr_spec["sub-type"] in NlAttr.type_formats:
                subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
                if 'enum' in attr_spec:
                    subattr = self._decode_enum(subattr, attr_spec)
                elif attr_spec.display_hint:
                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
                decoded.append(subattr)
            else:
                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
        return decoded

    def _decode_nest_type_value(self, attr, attr_spec):
        """Decode a nest-type-value attribute.

        Each name listed in 'type-value' consumes one level of nesting and
        gets the attribute type found at that level as its value, the
        innermost payload is decoded as 'nested-attributes'.
        """
        decoded = {}
        value = attr
        for name in attr_spec['type-value']:
            value = NlAttr(value.raw, 0)
            decoded[name] = value.type
        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
        decoded.update(subattrs)
        return decoded

    def _decode_unknown(self, attr):
        """Decode an attribute without a spec: recurse into nests, else return bytes."""
        if attr.is_nest:
            return self._decode(NlAttrs(attr.raw), None)
        else:
            return attr.as_bin()

    def _rsp_add(self, rsp, name, is_multi, decoded):
        """Store @decoded under @name in @rsp.

        @is_multi True appends to a list, False overwrites. None means
        "not known": the first value is stored plain and gets converted
        to a list once the same name shows up again.
        """
        if is_multi is None:
            if name in rsp:
                # Second occurrence turns the value into a list, any
                # later ones must append to it rather than replace it.
                if not isinstance(rsp[name], list):
                    rsp[name] = [rsp[name]]
                is_multi = True
            else:
                is_multi = False

        if not is_multi:
            rsp[name] = decoded
        elif name in rsp:
            rsp[name].append(decoded)
        else:
            rsp[name] = [decoded]

    def _resolve_selector(self, attr_spec, search_attrs):
        """Return the sub-message format selected by the value of the selector attr.

        Raises:
            Exception: if the sub-message spec or the format is not found
                (or the selector can't be looked up in @search_attrs).
        """
        sub_msg = attr_spec.sub_message
        if sub_msg not in self.sub_msgs:
            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
        sub_msg_spec = self.sub_msgs[sub_msg]

        selector = attr_spec.selector
        value = search_attrs.lookup(selector)
        if value not in sub_msg_spec.formats:
            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")

        return sub_msg_spec.formats[value]

    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
        """Decode a sub-message: optional fixed header followed by attributes."""
        msg_format = self._resolve_selector(attr_spec, search_attrs)
        decoded = {}
        offset = 0
        if msg_format.fixed_header:
            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
            offset = self._struct_size(msg_format.fixed_header)
        if msg_format.attr_set:
            if msg_format.attr_set in self.attr_sets:
                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
                decoded.update(subdict)
            else:
                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
        return decoded

    def _decode(self, attrs, space, outer_attrs = None):
        """Decode @attrs according to attribute set @space into a dict.

        Args:
            attrs: iterable of NlAttr.
            space: name of the attribute set, None if there is no spec
                (all attributes are then treated as unknown).
            outer_attrs: SpaceAttrs of the enclosing scope, if any.

        Raises:
            Exception: for attributes / types missing from the spec unless
                process_unknown is set. Errors hit while decoding a known
                attribute are re-raised after printing its name.
        """
        rsp = dict()
        attr_space = None
        search_attrs = None
        if space:
            attr_space = self.attr_sets[space]
            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)

        for attr in attrs:
            attr_spec = None
            if attr_space is not None:
                try:
                    attr_spec = attr_space.attrs_by_val[attr.type]
                except KeyError:
                    pass
            if attr_spec is None:
                if not self.process_unknown:
                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
                attr_name = f"UnknownAttr({attr.type})"
                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
                continue

            try:
                if attr_spec["type"] == 'nest':
                    subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
                    decoded = subdict
                elif attr_spec["type"] == 'string':
                    decoded = attr.as_strz()
                elif attr_spec["type"] == 'binary':
                    decoded = self._decode_binary(attr, attr_spec)
                elif attr_spec["type"] == 'flag':
                    decoded = True
                elif attr_spec.is_auto_scalar:
                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
                elif attr_spec["type"] in NlAttr.type_formats:
                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
                    if 'enum' in attr_spec:
                        decoded = self._decode_enum(decoded, attr_spec)
                    elif attr_spec.display_hint:
                        decoded = self._formatted_string(decoded, attr_spec.display_hint)
                elif attr_spec["type"] == 'indexed-array':
                    decoded = self._decode_array_attr(attr, attr_spec)
                elif attr_spec["type"] == 'bitfield32':
                    value, selector = struct.unpack("II", attr.raw)
                    if 'enum' in attr_spec:
                        value = self._decode_enum(value, attr_spec)
                        selector = self._decode_enum(selector, attr_spec)
                    decoded = {"value": value, "selector": selector}
                elif attr_spec["type"] == 'sub-message':
                    decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
                elif attr_spec["type"] == 'nest-type-value':
                    decoded = self._decode_nest_type_value(attr, attr_spec)
                else:
                    if not self.process_unknown:
                        raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
                    decoded = self._decode_unknown(attr)

                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
            except Exception:
                # Say which attribute we choked on, then let the error go
                print(f"Error decoding '{attr_spec.name}' from '{space}'")
                raise

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Find the attribute which starts at byte offset @target.

        Args:
            attrs: attributes to walk.
            attr_set: attribute set describing @attrs.
            offset: offset of the first of @attrs within the message.
            target: offset to find.

        Returns:
            Path like ".outer.inner", or None if no attribute starts
            at @target.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            # Target is inside of this attribute, we can only dive into nests
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Skip the header of the nest
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, op, extack):
        """Replace 'bad-attr-offs' in @extack with 'bad-attr', the attribute path.

        The path is found by re-parsing our own @request. @extack is left
        untouched if it has no offset or the offset can't be resolved.
        """
        if 'bad-attr-offs' not in extack:
            return

        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
                                        extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def _struct_size(self, name):
        """Return the size in bytes of struct @name, 0 if @name is empty / None."""
        if name:
            members = self.consts[name].members
            size = 0
            for m in members:
                if m.type in ['pad', 'binary']:
                    if m.struct:
                        size += self._struct_size(m.struct)
                    else:
                        size += m.len
                else:
                    fmt = NlAttr.get_format(m.type, m.byte_order)
                    size += fmt.size
            return size
        else:
            return 0

    def _decode_struct(self, data, name):
        """Decode bytes @data as struct @name, return a dict of its members.

        Padding is skipped, enums and display hints are applied.
        """
        members = self.consts[name].members
        attrs = dict()
        offset = 0
        for m in members:
            value = None
            if m.type == 'pad':
                offset += m.len
            elif m.type == 'binary':
                if m.struct:
                    sub_len = self._struct_size(m.struct)
                    value = self._decode_struct(data[offset : offset + sub_len],
                                                m.struct)
                    offset += sub_len
                else:
                    value = data[offset : offset + m.len]
                    offset += m.len
            else:
                fmt = NlAttr.get_format(m.type, m.byte_order)
                [ value ] = fmt.unpack_from(data, offset)
                offset += fmt.size
            if value is not None:
                if m.enum:
                    value = self._decode_enum(value, m)
                elif m.display_hint:
                    value = self._formatted_string(value, m.display_hint)
                attrs[m.name] = value
        return attrs

    def _encode_struct(self, name, vals):
        """Encode struct @name using values from @vals, return the bytes.

        Members are removed from @vals as they get consumed, so that
        callers can encode what's left as attributes. Members without
        a value are encoded as zeros.
        """
        members = self.consts[name].members
        attr_payload = b''
        for m in members:
            value = vals.pop(m.name, None)
            if m.type == 'pad':
                attr_payload += bytearray(m.len)
            elif m.type == 'binary':
                if m.struct:
                    if value is None:
                        value = dict()
                    attr_payload += self._encode_struct(m.struct, value)
                else:
                    if value is None:
                        attr_payload += bytearray(m.len)
                    else:
                        attr_payload += bytes.fromhex(value)
            else:
                if value is None:
                    value = 0
                fmt = NlAttr.get_format(m.type, m.byte_order)
                attr_payload += fmt.pack(value)
        return attr_payload

    def _formatted_string(self, raw, display_hint):
        """Format @raw for display per @display_hint, unknown hints return @raw as-is."""
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            if isinstance(raw, int):
                formatted = hex(raw)
            else:
                formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def _from_string(self, string, attr_spec):
        """Parse a display-hint formatted string, only 'ipv4' / 'ipv6' are supported.

        Returns packed bytes for binary attributes, an int otherwise.
        """
        if attr_spec.display_hint in ['ipv4', 'ipv6']:
            ip = ipaddress.ip_address(string)
            if attr_spec['type'] == 'binary':
                raw = ip.packed
            else:
                raw = int(ip)
        else:
            raise Exception(f"Display hint '{attr_spec.display_hint}' not implemented"
                            f" when parsing '{attr_spec['name']}'")
        return raw

    def handle_ntf(self, decoded):
        """Decode a notification and put it on async_msg_queue.

        The queued dict has 'name' and 'msg', plus 'raw' if include_raw is set.
        """
        msg = dict()
        if self.include_raw:
            msg['raw'] = decoded
        op = self.rsp_by_value[decoded.cmd()]
        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
        if op.fixed_header:
            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))

        msg['name'] = op['name']
        msg['msg'] = attrs
        self.async_msg_queue.put(msg)

    def check_ntf(self):
        """Read everything pending on the socket without blocking, queue notifications.

        Errors, done messages and non-async messages are printed and skipped.
        """
        while True:
            try:
                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                decoded = self.nlproto.decode(self, nl_msg, None)
                if decoded.cmd() not in self.async_msg_ids:
                    print("Unexpected msg id while checking for ntf", decoded)
                    continue

                self.handle_ntf(decoded)

    def poll_ntf(self, duration=None):
        """Generator yielding notifications as they arrive.

        Args:
            duration: stop after this many seconds, None means poll forever.
        """
        # Monotonic clock, so wall clock adjustments don't affect the wait
        start_time = time.monotonic()
        # The with makes sure the selector is closed when the generator is
        with selectors.DefaultSelector() as selector:
            selector.register(self.sock, selectors.EVENT_READ)

            while True:
                try:
                    msg = self.async_msg_queue.get_nowait()
                except queue.Empty:
                    if duration is not None:
                        timeout = start_time + duration - time.monotonic()
                        if timeout <= 0:
                            return
                    else:
                        timeout = None
                    events = selector.select(timeout)
                    if events:
                        self.check_ntf()
                    continue
                yield msg

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _encode_message(self, op, vals, flags, req_seq):
        """Build the complete request for @op, return the bytes.

        NLM_F_REQUEST | NLM_F_ACK are always set, @flags get ORed in.
        Fixed header members are consumed from @vals, see _encode_struct().
        """
        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag

        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            msg += self._encode_struct(op.fixed_header, vals)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
        msg = _genl_msg_finalize(msg)
        return msg

    def _ops(self, ops):
        """Send a batch of requests, collect the responses.

        Args:
            ops: iterable of (method, vals, flags) tuples.

        Returns:
            List with one entry per completed request: for dumps the list
            of responses, otherwise None if there was no response, the
            response itself if there was one, or a list of them.

        Raises:
            NlError: if the kernel responds with an error.
        """
        reqs_by_seq = {}
        req_seq = random.randint(1024, 65535)
        payload = b''
        op = None
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            # Normalize so the NLM_F_DUMP check below works for None flags
            reqs_by_seq[req_seq] = (op, msg, flags or [])
            payload += msg
            req_seq += 1

        if not reqs_by_seq:
            # Nothing to send, waiting for a reply would block forever
            return []

        self.sock.send(payload, 0)

        # Attribute set used to name 'miss-type' extacks. Follows the most
        # recently matched request, and is tracked separately from 'op'
        # since that's None while processing messages which aren't ours.
        extack_space = op.attr_set

        done = False
        rsp = []
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            nms = NlMsgs(reply, attr_space=extack_space)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    extack_space = op.attr_set
                    if nl_msg.extack:
                        self._decode_extack(req_msg, op, nl_msg.extack)
                else:
                    op = None
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    break

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp

    def _op(self, method, vals, flags=None, dump=False):
        """Execute a single operation, return its result (see _ops()).

        Args:
            method: name of the operation.
            vals: dict of attribute (and fixed header member) values.
            flags: optional list of extra netlink flags.
            dump: if True add NLM_F_DUMP to the flags.
        """
        # Copy, so that adding the dump flag doesn't modify caller's list
        req_flags = list(flags) if flags else []
        if dump:
            req_flags.append(Netlink.NLM_F_DUMP)

        ops = [(method, vals, req_flags)]
        return self._ops(ops)[0]

    def do(self, method, vals, flags=None):
        """Execute operation @method as a "do" request."""
        return self._op(method, vals, flags)

    def dump(self, method, vals):
        """Execute operation @method as a dump request."""
        return self._op(method, vals, dump=True)

    def do_multi(self, ops):
        """Execute multiple operations in one batch, see _ops()."""
        return self._ops(ops)