# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
from enum import Enum
import functools
import os
import random
import socket
import struct
from struct import Struct
import sys
import yaml
import ipaddress
import uuid
import queue
import selectors
import time

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Namespace for the numeric netlink / genetlink constants used below."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Member values are assigned by Enum starting at 1, in list order.
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                  's8', 's16', 's32', 's64',
                                  'binary', 'string', 'nul-string',
                                  'nested', 'nested-array',
                                  'bitfield32', 'sint', 'uint'])


class NlError(Exception):
    """Raised for a netlink message that carries a non-zero error code.

    ``error`` holds the negated ``nl_msg.error`` (the value handed to
    os.strerror() when formatting); ``nl_msg`` is the offending message.
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"


class ConfigError(Exception):
    """Raised when the library is configured with an unusable setting."""


class NlAttr:
    """A single netlink attribute parsed out of a raw buffer."""

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8': ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8': ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute whose 4-byte header starts at raw[offset]."""
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        # NB: despite the name this is the length read from the header,
        # which is used below as covering the 4 header bytes too.
        self.payload_len = self._len
        # Length rounded up to the next multiple of 4.
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for attr_type in the requested byte order.

        A falsy byte_order selects native order, "big-endian" selects big
        endian and any other truthy value selects little endian.
        """
        formats = cls.type_formats[attr_type]
        if byte_order:
            return formats.big if byte_order == "big-endian" \
                   else formats.little
        return formats.native

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as one scalar of the given type."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a scalar whose width (32 or 64 bit) follows the payload size.

        Only the first character of attr_type is used, as the signedness
        prefix of the concrete format name.
        """
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar len payload be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII and drop its last character."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of native-order scalars."""
        fmt = self.get_format(type)
        return [x[0] for x in fmt.iter_unpack(self.raw)]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """The sequence of attributes found in a buffer, starting at offset."""

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            if not attr.full_len:
                # A zero length would never advance the offset and the
                # loop would append attributes forever.
                raise Exception(f"Zero-length netlink attribute at offset {offset}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """One netlink message: header fields, payload, error code and extack."""

    def __init__(self, msg, offset, attr_space=None):
        """Parse the message whose 16-byte header starts at msg[offset].

        attr_space, when given, is used to turn a numeric 'miss-type'
        extack value into the attribute name (and doc, if present).
        """
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # Extack attributes are read from byte 20 of the payload.
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        self.extack['miss-type'] = spec['name']
                        if 'doc' in spec:
                            self.extack['miss-type-doc'] = spec['doc']

    def _decode_policy(self, raw):
        """Decode a nested policy description into a plain dict."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_id = attr.as_scalar('u32')
                try:
                    policy['type'] = Netlink.AttrType(type_id).name
                except ValueError:
                    # Type not listed in AttrType; keep the number rather
                    # than failing while an error is being reported.
                    policy['type'] = type_id
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def cmd(self):
        """Return the message type, which acts as the command id here."""
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages found back to back in one received buffer."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            if not msg.nl_len:
                # A zero length would never advance the offset and the
                # loop would append messages forever.
                raise Exception(f"Zero-length netlink message at offset {offset}")
            offset += msg.nl_len
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Cache of family name -> info dict, filled in by _genl_load_families().
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build a netlink + genetlink header pair, without the length word."""
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the total message length (including the length word itself)."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump the families from nlctrl into genl_family_name_to_id.

    Each recorded family is a dict with 'id' and 'name' and, when the
    attributes were present in the reply, 'maxattr' and 'mcast'
    (group name -> group id). On a netlink error the code is printed and
    loading stops with whatever was collected so far.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in NlAttrs(gm.raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """A netlink message with its 4-byte genetlink header split off."""

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]
        # Replaced by NetlinkProtocol.decode(); the default keeps
        # __repr__ usable on a message that was never decoded.
        self.raw_attrs = []

    def cmd(self):
        """Return the genetlink command id."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        for a in self.raw_attrs:
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Message framing for a raw (non-generic) netlink family."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack the netlink header minus the leading length word."""
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        """Build a request header; the command is used as the message type."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Attach the parsed attributes to the message as ``raw_attrs``.

        When op is None it is looked up in ynl.rsp_by_value by command id.
        Attributes are parsed after the op's fixed header, if it has one.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the multicast group id taken from the spec's groups."""
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value

    def msghdr_size(self):
        """Size in bytes of the headers preceding the payload."""
        return 16


class GenlProtocol(NetlinkProtocol):
    """Message framing for a generic netlink family."""

    def __init__(self, family_name):
        """Resolve the family, loading the family list on first use.

        Raises KeyError when family_name is not in the loaded list.
        """
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']

    def message(self, flags, command, version, seq=None):
        """Build netlink + genetlink request headers for this family."""
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the multicast group id recorded for the loaded family."""
        # 'mcast' is only present when the family advertised groups.
        if mcast_name not in self.genl_family.get('mcast', {}):
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return self.genl_family['mcast'][mcast_name]

    def msghdr_size(self):
        """Netlink header size plus the 4-byte genetlink header."""
        return super().msghdr_size() + 4


class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes, innermost first."""

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        outer_scopes = outer.scopes if outer else []
        inner_scope = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner_scope] + outer_scopes

    def lookup(self, name):
        """Return the value of name from the innermost scope defining it.

        Raises when that scope defines the attribute but holds no value
        for it, or when no scope defines the attribute at all.
        """
        for scope in self.scopes:
            if name in scope.spec:
                if name in scope.values:
                    return scope.values[name]
                spec_name = scope.spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """A netlink family driven by its YAML spec.

    Every operation in the spec is exposed as a method named after the
    operation's ``ident_name``; it forwards to _op().
    """

    def __init__(self, def_path, schema=None, process_unknown=False,
                 recv_size=0):
        """Load the spec and open the netlink socket.

        process_unknown -- decode attributes missing from the spec instead
                           of raising.
        recv_size       -- recv() buffer size; 0 selects 131072. Values
                           below 4000 raise ConfigError.
        """
        super().__init__(def_path, schema)

        self.include_raw = False
        self.process_unknown = process_unknown

        try:
            if self.proto == "netlink-raw":
                self.nlproto = NetlinkProtocol(self.yaml['name'],
                                               self.yaml['protonum'])
            else:
                self.nlproto = GenlProtocol(self.yaml['name'])
        except KeyError as err:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel") from err

        self._recv_dbg = False
        # Note that netlink will use conservative (min) message size for
        # the first dump recv() on the socket, our setting will only matter
        # from the second recv() on.
        self._recv_size = recv_size if recv_size else 131072
        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
        # for a message, so smaller receive sizes will lead to truncation.
        # Note that the min size for other families may be larger than 4k!
        if self._recv_size < 4000:
            raise ConfigError(f"recv_size must be at least 4000, got {self._recv_size}")

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

        self.async_msg_ids = set()
        self.async_msg_queue = queue.Queue()

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

    def ntf_subscribe(self, mcast_name):
        """Bind the socket and join the named multicast group."""
        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_id)

    def set_recv_dbg(self, enabled):
        """Turn printing of received messages to stderr on or off."""
        self._recv_dbg = enabled

    def _recv_dbg_print(self, reply, nl_msgs):
        if not self._recv_dbg:
            return
        print("Recv: read", len(reply), "bytes,",
              len(nl_msgs.msgs), "messages", file=sys.stderr)
        for nl_msg in nl_msgs:
            print("  ", nl_msg, file=sys.stderr)

    def _encode_enum(self, attr_spec, value):
        """Translate enum entry name(s) into the numeric wire value."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            scalar = 0
            if isinstance(value, str):
                value = [value]
            for single_value in value:
                # OR rather than add, so naming a flag twice (or naming
                # overlapping masks) cannot corrupt the result.
                scalar |= enum.entries[single_value].user_value(as_flags = True)
            return scalar
        else:
            return enum.entries[value].user_value()

    def _get_scalar(self, attr_spec, value):
        """Coerce value to an int, via enum names or display hints if needed."""
        try:
            return int(value)
        except (ValueError, TypeError):
            if 'enum' in attr_spec:
                return self._encode_enum(attr_spec, value)
            if attr_spec.display_hint:
                return self._from_string(value, attr_spec)
            raise

    def _add_attr(self, space, name, value, search_attrs):
        """Encode one attribute (recursing into nests) and return its bytes.

        search_attrs is the SpaceAttrs chain used to resolve sub-message
        selectors. A list value for a multi attribute emits one attribute
        per element. The result is padded to a multiple of 4 bytes.
        """
        try:
            attr = self.attr_sets[space][name]
        except KeyError as err:
            raise Exception(f"Space '{space}' has no attribute '{name}'") from err
        nl_type = attr.value

        if attr.is_multi and isinstance(value, list):
            attr_payload = b''
            for subvalue in value:
                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
            return attr_payload

        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            sub_space = attr['nested-attributes']
            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
        elif attr["type"] == 'flag':
            if not value:
                # If value is absent or false then skip attribute creation.
                return b''
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            if isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                if attr.display_hint:
                    attr_payload = self._from_string(value, attr)
                else:
                    attr_payload = bytes.fromhex(value)
            elif isinstance(value, dict) and attr.struct_name:
                attr_payload = self._encode_struct(attr.struct_name, value)
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
            scalar = self._get_scalar(attr, value)
            if attr.is_auto_scalar:
                # Use the 32-bit form when the value fits, else 64-bit.
                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
            else:
                attr_type = attr["type"]
            fmt = NlAttr.get_format(attr_type, attr.byte_order)
            attr_payload = fmt.pack(scalar)
        elif attr['type'] == "bitfield32":
            scalar_value = self._get_scalar(attr, value["value"])
            scalar_selector = self._get_scalar(attr, value["selector"])
            attr_payload = struct.pack("II", scalar_value, scalar_selector)
        elif attr['type'] == 'sub-message':
            msg_format, _ = self._resolve_selector(attr, search_attrs)
            attr_payload = b''
            if msg_format.fixed_header:
                attr_payload += self._encode_struct(msg_format.fixed_header, value)
            if msg_format.attr_set:
                if msg_format.attr_set in self.attr_sets:
                    nl_type |= Netlink.NLA_F_NESTED
                    # SpaceAttrs needs the attribute-set object, not its
                    # name, exactly as in the 'nest' branch above.
                    sub_attrs = SpaceAttrs(self.attr_sets[msg_format.attr_set],
                                           value, search_attrs)
                    for subname, subvalue in value.items():
                        attr_payload += self._add_attr(msg_format.attr_set,
                                                       subname, subvalue, sub_attrs)
                else:
                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Map a numeric value to an entry name, or a set of names for flags."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attribute as a struct, a scalar array or raw bytes."""
        if attr_spec.struct_name:
            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
            if 'enum' in attr_spec:
                decoded = [self._decode_enum(x, attr_spec) for x in decoded]
            elif attr_spec.display_hint:
                decoded = [self._formatted_string(x, attr_spec.display_hint)
                           for x in decoded]
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = self._formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode_array_attr(self, attr, attr_spec):
        """Decode an indexed-array attribute into a list of its items."""
        decoded = []
        offset = 0
        while offset < len(attr.raw):
            item = NlAttr(attr.raw, offset)
            if not item.full_len:
                # A zero length would never advance the offset.
                raise Exception(f"Zero-length item in array '{attr_spec['name']}'")
            offset += item.full_len

            if attr_spec["sub-type"] == 'nest':
                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
                decoded.append({ item.type: subattrs })
            elif attr_spec["sub-type"] == 'binary':
                subattr = item.as_bin()
                if attr_spec.display_hint:
                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
                decoded.append(subattr)
            elif attr_spec["sub-type"] in NlAttr.type_formats:
                subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
                if 'enum' in attr_spec:
                    subattr = self._decode_enum(subattr, attr_spec)
                elif attr_spec.display_hint:
                    subattr = self._formatted_string(subattr, attr_spec.display_hint)
                decoded.append(subattr)
            else:
                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
        return decoded

    def _decode_nest_type_value(self, attr, attr_spec):
        """Decode a nest whose nesting levels carry values in their types.

        One level is unwrapped per name in 'type-value', recording that
        level's attribute type under the name; the innermost payload is
        then decoded as the nested attribute set.
        """
        decoded = {}
        value = attr
        for name in attr_spec['type-value']:
            value = NlAttr(value.raw, 0)
            decoded[name] = value.type
        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
        decoded.update(subattrs)
        return decoded

    def _decode_unknown(self, attr):
        """Best-effort decode of an attribute that has no spec."""
        if attr.is_nest:
            return self._decode(NlAttrs(attr.raw), None)
        else:
            return attr.as_bin()

    def _rsp_add(self, rsp, name, is_multi, decoded):
        """Store decoded under name, collecting repeats into a list.

        is_multi of None means "unknown": the entry becomes a list once
        the same name shows up a second time.
        """
        if is_multi is None:
            if name in rsp and not isinstance(rsp[name], list):
                rsp[name] = [rsp[name]]
                is_multi = True
            else:
                is_multi = False

        if not is_multi:
            rsp[name] = decoded
        elif name in rsp:
            rsp[name].append(decoded)
        else:
            rsp[name] = [decoded]

    def _resolve_selector(self, attr_spec, search_attrs):
        """Return (format spec, selector value) for a sub-message attribute."""
        sub_msg = attr_spec.sub_message
        if sub_msg not in self.sub_msgs:
            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
        sub_msg_spec = self.sub_msgs[sub_msg]

        selector = attr_spec.selector
        value = search_attrs.lookup(selector)
        if value not in sub_msg_spec.formats:
            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")

        spec = sub_msg_spec.formats[value]
        return spec, value

    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
        """Decode a sub-message using the format its selector picks."""
        msg_format, _ = self._resolve_selector(attr_spec, search_attrs)
        decoded = {}
        offset = 0
        if msg_format.fixed_header:
            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
            offset = self._struct_size(msg_format.fixed_header)
        if msg_format.attr_set:
            if msg_format.attr_set in self.attr_sets:
                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
                decoded.update(subdict)
            else:
                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
        return decoded

    def _decode(self, attrs, space, outer_attrs = None):
        """Decode attrs per the named attribute set into a dict.

        A falsy space means no spec is available, so every attribute is
        handled as unknown. Unknown attributes raise unless
        process_unknown is set.
        """
        rsp = dict()
        attr_space = None
        search_attrs = None
        if space:
            attr_space = self.attr_sets[space]
            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)

        for attr in attrs:
            if attr_space is None or attr.type not in attr_space.attrs_by_val:
                if not self.process_unknown:
                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
                attr_name = f"UnknownAttr({attr.type})"
                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
                continue
            attr_spec = attr_space.attrs_by_val[attr.type]

            try:
                if attr_spec["type"] == 'nest':
                    subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
                    decoded = subdict
                elif attr_spec["type"] == 'string':
                    decoded = attr.as_strz()
                elif attr_spec["type"] == 'binary':
                    decoded = self._decode_binary(attr, attr_spec)
                elif attr_spec["type"] == 'flag':
                    decoded = True
                elif attr_spec.is_auto_scalar:
                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
                elif attr_spec["type"] in NlAttr.type_formats:
                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
                    if 'enum' in attr_spec:
                        decoded = self._decode_enum(decoded, attr_spec)
                    elif attr_spec.display_hint:
                        decoded = self._formatted_string(decoded, attr_spec.display_hint)
                elif attr_spec["type"] == 'indexed-array':
                    decoded = self._decode_array_attr(attr, attr_spec)
                elif attr_spec["type"] == 'bitfield32':
                    value, selector = struct.unpack("II", attr.raw)
                    if 'enum' in attr_spec:
                        value = self._decode_enum(value, attr_spec)
                        selector = self._decode_enum(selector, attr_spec)
                    decoded = {"value": value, "selector": selector}
                elif attr_spec["type"] == 'sub-message':
                    decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
                elif attr_spec["type"] == 'nest-type-value':
                    decoded = self._decode_nest_type_value(attr, attr_spec)
                else:
                    if not self.process_unknown:
                        raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
                    decoded = self._decode_unknown(attr)

                self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
            except Exception:
                # Say which attribute failed, then let the error propagate.
                print(f"Error decoding '{attr_spec.name}' from '{space}'")
                raise

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs):
        """Find the dotted attribute path located at byte offset target.

        offset is the position of the first attribute in attrs within the
        request. Returns None when no attribute starts at target.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError as err:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'") from err
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue

            # The target lies inside this attribute, descend into it.
            pathname = attr_spec.name
            if attr_spec['type'] == 'nest':
                sub_attrs = self.attr_sets[attr_spec['nested-attributes']]
                search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name']))
            elif attr_spec['type'] == 'sub-message':
                msg_format, value = self._resolve_selector(attr_spec, search_attrs)
                if msg_format is None:
                    raise Exception(f"Can't resolve sub-message of {attr_spec['name']} for extack")
                sub_attrs = self.attr_sets[msg_format.attr_set]
                pathname += f"({value})"
            else:
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Step over the 4-byte attribute header.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs,
                                               offset, target, search_attrs)
            if subpath is None:
                return None
            return '.' + pathname + subpath

        return None

    def _decode_extack(self, request, op, extack, vals):
        """Replace 'bad-attr-offs' in extack by a readable 'bad-attr' path.

        The request bytes are re-parsed to map the offset back to an
        attribute. extack is left untouched when no path is found.
        """
        if 'bad-attr-offs' not in extack:
            return

        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
                                        extack['bad-attr-offs'], search_attrs)
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def _annotate_extack(self, extack, attr_space):
        """Replace a numeric 'miss-type' by the attribute name from attr_space."""
        # Nests can't be resolved here, so only handle top level attributes.
        if 'miss-type' not in extack or 'miss-nest' in extack:
            return
        miss_type = extack['miss-type']
        if miss_type in attr_space.attrs_by_val:
            spec = attr_space.attrs_by_val[miss_type]
            extack['miss-type'] = spec['name']
            if 'doc' in spec:
                extack['miss-type-doc'] = spec['doc']

    def _struct_size(self, name):
        """Return the byte size of the named struct, 0 for a falsy name."""
        if name:
            members = self.consts[name].members
            size = 0
            for m in members:
                if m.type in ['pad', 'binary']:
                    if m.struct:
                        size += self._struct_size(m.struct)
                    else:
                        size += m.len
                else:
                    fmt = NlAttr.get_format(m.type, m.byte_order)
                    size += fmt.size
            return size
        else:
            return 0

    def _decode_struct(self, data, name):
        """Decode data as the named struct into a dict keyed by member name.

        Pad members are skipped and do not appear in the result.
        """
        members = self.consts[name].members
        attrs = dict()
        offset = 0
        for m in members:
            value = None
            if m.type == 'pad':
                offset += m.len
            elif m.type == 'binary':
                if m.struct:
                    size = self._struct_size(m.struct)
                    value = self._decode_struct(data[offset : offset + size],
                                                m.struct)
                    offset += size
                else:
                    value = data[offset : offset + m.len]
                    offset += m.len
            else:
                fmt = NlAttr.get_format(m.type, m.byte_order)
                [ value ] = fmt.unpack_from(data, offset)
                offset += fmt.size
            if value is not None:
                if m.enum:
                    value = self._decode_enum(value, m)
                elif m.display_hint:
                    value = self._formatted_string(value, m.display_hint)
                attrs[m.name] = value
        return attrs

    def _encode_struct(self, name, vals):
        """Encode the named struct from vals, zero filling absent members.

        NB: consumed members are popped from vals, so the caller's dict is
        left holding only the entries that are not struct members.
        """
        members = self.consts[name].members
        attr_payload = b''
        for m in members:
            value = vals.pop(m.name) if m.name in vals else None
            if m.type == 'pad':
                attr_payload += bytearray(m.len)
            elif m.type == 'binary':
                if m.struct:
                    if value is None:
                        value = dict()
                    attr_payload += self._encode_struct(m.struct, value)
                else:
                    if value is None:
                        attr_payload += bytearray(m.len)
                    else:
                        attr_payload += bytes.fromhex(value)
            else:
                if value is None:
                    value = 0
                fmt = NlAttr.get_format(m.type, m.byte_order)
                attr_payload += fmt.pack(value)
        return attr_payload

    def _formatted_string(self, raw, display_hint):
        """Render raw per display_hint; unknown hints return raw unchanged."""
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            if isinstance(raw, int):
                formatted = hex(raw)
            else:
                formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def _from_string(self, string, attr_spec):
        """Parse a display-hinted string; only 'ipv4'/'ipv6' are implemented.

        Returns packed bytes for binary attributes and an int otherwise.
        """
        if attr_spec.display_hint in ['ipv4', 'ipv6']:
            ip = ipaddress.ip_address(string)
            if attr_spec['type'] == 'binary':
                raw = ip.packed
            else:
                raw = int(ip)
        else:
            raise Exception(f"Display hint '{attr_spec.display_hint}' not implemented"
                            f" when parsing '{attr_spec['name']}'")
        return raw

    def handle_ntf(self, decoded):
        """Decode a notification and put it on async_msg_queue."""
        msg = dict()
        if self.include_raw:
            msg['raw'] = decoded
        op = self.rsp_by_value[decoded.cmd()]
        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
        if op.fixed_header:
            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))

        msg['name'] = op['name']
        msg['msg'] = attrs
        self.async_msg_queue.put(msg)

    def check_ntf(self):
        """Drain the socket without blocking, queueing any notifications."""
        while True:
            try:
                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                decoded = self.nlproto.decode(self, nl_msg, None)
                if decoded.cmd() not in self.async_msg_ids:
                    print("Unexpected msg id while checking for ntf", decoded)
                    continue

                self.handle_ntf(decoded)

    def poll_ntf(self, duration=None):
        """Yield notifications as they arrive.

        With a duration (seconds) the generator ends once that much time
        has passed and the queue is empty; without one it waits forever.
        """
        start_time = time.time()
        # The context manager closes the selector when the generator
        # finishes or is closed; it was previously never closed.
        with selectors.DefaultSelector() as selector:
            selector.register(self.sock, selectors.EVENT_READ)

            while True:
                try:
                    yield self.async_msg_queue.get_nowait()
                except queue.Empty:
                    if duration is not None:
                        timeout = start_time + duration - time.time()
                        if timeout <= 0:
                            return
                    else:
                        timeout = None
                    events = selector.select(timeout)
                    if events:
                        self.check_ntf()

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _encode_message(self, op, vals, flags, req_seq):
        """Build the complete request message for op from vals."""
        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag

        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            # Pops the header members out of vals, see _encode_struct().
            msg += self._encode_struct(op.fixed_header, vals)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
        msg = _genl_msg_finalize(msg)
        return msg

    def _ops(self, ops):
        """Send a batch of (method, vals, flags) requests, return the replies.

        The result holds one entry per completed request: a list for
        dumps, otherwise None, the single reply or a list of replies.
        Raises NlError on the first message carrying an error.
        """
        reqs_by_seq = {}
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            # Store a list even when the caller passed flags=None, the
            # reply loop does membership tests on it.
            reqs_by_seq[req_seq] = (op, vals, msg, flags or [])
            payload += msg
            req_seq += 1

        rsp = []
        if not reqs_by_seq:
            # Nothing was requested, so there is no reply to wait for.
            return rsp

        self.sock.send(payload, 0)

        done = False
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            # Parse without an attribute space: which request (and hence
            # which attribute set) a message belongs to is only known once
            # its sequence number has been matched below.
            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    if nl_msg.extack:
                        self._annotate_extack(nl_msg.extack, op.attr_set)
                        self._decode_extack(req_msg, op, nl_msg.extack, vals)
                else:
                    op = None
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    break

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp

    def _op(self, method, vals, flags=None, dump=False):
        """Run a single operation and return its reply."""
        # Copy, so that adding the dump flag never alters the caller's list.
        req_flags = list(flags or [])
        if dump:
            req_flags.append(Netlink.NLM_F_DUMP)

        ops = [(method, vals, req_flags)]
        return self._ops(ops)[0]

    def do(self, method, vals, flags=None):
        """Run a 'do' request for the named operation."""
        return self._op(method, vals, flags)

    def dump(self, method, vals):
        """Run a 'dump' request for the named operation."""
        return self._op(method, vals, dump=True)

    def do_multi(self, ops):
        """Run several (method, vals, flags) requests in one send."""
        return self._ops(ops)