# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
from enum import Enum
import functools
import os
import random
import socket
import struct
from struct import Struct
import sys
import yaml
import ipaddress
import uuid

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Namespace for the netlink / genetlink constants used by this module."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Functional Enum API: members are numbered from 1 in list order.
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                  's8', 's16', 's32', 's64',
                                  'binary', 'string', 'nul-string',
                                  'nested', 'nested-array',
                                  'bitfield32', 'sint', 'uint'])

class NlError(Exception):
    """Raised for a netlink message carrying a non-zero error code.

    `nl_msg` is the offending NlMsg; `error` is the positive errno derived
    from the (negative) error value in the message.
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"


class ConfigError(Exception):
    """Raised when the caller-supplied configuration is unusable."""
    pass


class NlAttr:
    """A single netlink attribute (length/type header plus payload)."""

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    # Pre-compiled struct formats for every fixed-width scalar type, in
    # native, big-endian and little-endian flavours.
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute which starts at `offset` within `raw`.

        Raises an Exception if the length field is smaller than the 4 byte
        attribute header; such an attribute is malformed and a zero length
        would make callers that advance by `full_len` loop forever.
        """
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        if self._len < 4:
            raise Exception(f"Malformed attribute at offset {offset}: "
                            f"length {self._len} is shorter than the header")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        # NB: despite the name this is the length including the 4 byte header
        self.payload_len = self._len
        # Length rounded up to the 4 byte attribute alignment
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for `attr_type` in the requested byte order.

        `byte_order` of "big-endian" selects big endian, any other truthy
        value selects little endian, and a falsy value selects native order.
        Raises KeyError if `attr_type` is not in `type_formats`.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" \
                else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as a fixed-width scalar of type `attr_type`."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable-width ('sint' / 'uint') scalar.

        The width is taken from the payload size, which must be 4 or 8 bytes;
        the signedness comes from the first letter of `attr_type`.
        """
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar payload len must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII, dropping the last (terminator) byte."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of native scalars of `type`."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """The sequence of attributes found in `msg` starting at `offset`."""

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """A netlink message: parsed header, raw payload and decoded extack."""

    def __init__(self, msg, offset, attr_space=None):
        """Parse the message which starts at `offset` within `msg`.

        If `attr_space` is given it is used to translate a top-level
        'miss-type' extack value into the attribute's name (and doc).
        """
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)
        if self.nl_len < 16:
            # A length of zero would make NlMsgs loop forever.
            raise Exception(f"Malformed netlink message at offset {offset}: "
                            f"length {self.nl_len} is shorter than the header")

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        # Offset of the extack attributes within the payload, if any
        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # error code (4) + header of the request (16)
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        self.extack['miss-type'] = spec['name']
                        if 'doc' in spec:
                            self.extack['miss-type-doc'] = spec['doc']

    def _decode_policy(self, raw):
        """Decode the nested policy description attributes into a dict."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_id = attr.as_scalar('u32')
                policy['type'] = Netlink.AttrType(type_id).name
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def cmd(self):
        """Return the command identifier, for raw netlink the message type."""
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All the netlink messages packed back to back in `data`."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            offset += msg.nl_len
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Cache of genetlink families known to the kernel, keyed by family name.
# Populated lazily by _genl_load_families().
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build a genetlink request, without the leading length field."""
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the total message length (including the length field itself)."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump the genetlink families from nlctrl into genl_family_name_to_id.

    Each entry maps the family name to a dict with 'id', 'name' and, when
    reported, 'maxattr' and 'mcast' (multicast group name -> group id).
    On a netlink error the error is printed and the cache is left with
    whatever families were parsed so far.
    """
    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        global genl_family_name_to_id
        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in NlAttrs(gm.raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """A genetlink message: the genetlink header split off an NlMsg."""

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        # Payload past the 4 byte genetlink header
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the genetlink command of the message."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is only attached by NetlinkProtocol.decode(), messages
        # which were never decoded simply have no attributes to show.
        for a in getattr(self, 'raw_attrs', None) or []:
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Message framing for classic ("raw") netlink families."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Build the netlink header, without the leading length field."""
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        """Build a request header; `command` is used as the message type."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Attach `raw_attrs` (attributes past the fixed header) to the message.

        When `op` is None the operation is looked up in `ynl.rsp_by_value`
        based on the command of the message.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the multicast group id as defined by the spec."""
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value

    def msghdr_size(self):
        """Size of the message header(s) preceding the fixed header / attrs."""
        return 16


class GenlProtocol(NetlinkProtocol):
    """Message framing for generic netlink families."""

    def __init__(self, family_name):
        """Resolve `family_name` to its id; KeyError if the family is unknown."""
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']

    def message(self, flags, command, version, seq=None):
        """Build netlink + genetlink headers, without the length field."""
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the multicast group id as reported by the kernel."""
        # 'mcast' is only present if the family reported multicast groups
        if mcast_name not in self.genl_family.get('mcast', {}):
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return self.genl_family['mcast'][mcast_name]

    def msghdr_size(self):
        return super().msghdr_size() + 4


class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes, innermost first."""

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        outer_scopes = outer.scopes if outer else []
        inner_scope = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner_scope] + outer_scopes

    def lookup(self, name):
        """Return the value of `name` from the innermost scope defining it.

        Raises an Exception if the defining scope holds no value for the
        attribute, or if no scope defines the attribute at all.
        """
        for scope in self.scopes:
            if name in scope.spec:
                if name in scope.values:
                    return scope.values[name]
                spec_name = scope.spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """A netlink family loaded from a YAML spec, bound to a netlink socket.

    Every operation of the spec is exposed as a method named after the
    operation's `ident_name`.
    """

    def __init__(self, def_path, schema=None, process_unknown=False,
                 recv_size=0):
        """Load the spec at `def_path` and open a socket for the family.

        `process_unknown` makes decoding tolerate attributes missing from
        the spec. `recv_size` is the recv() buffer size, 0 selects the
        default. Raises ConfigError if `recv_size` is too small, and an
        Exception if looking the family up fails with a KeyError.
        """
        super().__init__(def_path, schema)

        self.include_raw = False
        self.process_unknown = process_unknown

        try:
            if self.proto == "netlink-raw":
                self.nlproto = NetlinkProtocol(self.yaml['name'],
                                               self.yaml['protonum'])
            else:
                self.nlproto = GenlProtocol(self.yaml['name'])
        except KeyError:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

        self._recv_dbg = False
        # Note that netlink will use conservative (min) message size for
        # the first dump recv() on the socket, our setting will only matter
        # from the second recv() on.
        self._recv_size = recv_size if recv_size else 131072
        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
        # for a message, so smaller receive sizes will lead to truncation.
        # Note that the min size for other families may be larger than 4k!
        if self._recv_size < 4000:
            raise ConfigError()

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)


    def ntf_subscribe(self, mcast_name):
        """Bind the socket and join the multicast group `mcast_name`."""
        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_id)

    def set_recv_dbg(self, enabled):
        """Enable / disable printing of received messages to stderr."""
        self._recv_dbg = enabled

    def _recv_dbg_print(self, reply, nl_msgs):
        if not self._recv_dbg:
            return
        print("Recv: read", len(reply), "bytes,",
              len(nl_msgs.msgs), "messages", file=sys.stderr)
        for nl_msg in nl_msgs:
            print("  ", nl_msg, file=sys.stderr)

    def _encode_enum(self, attr_spec, value):
        """Translate enum entry name(s) into the scalar sent to the kernel.

        For flags `value` may be a single name or an iterable of names.
        """
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            scalar = 0
            if isinstance(value, str):
                value = [value]
            for single_value in value:
                # OR rather than add, so a repeated flag is not counted twice
                scalar |= enum.entries[single_value].user_value(as_flags = True)
            return scalar
        return enum.entries[value].user_value()

    def _get_scalar(self, attr_spec, value):
        """Return `value` as an int, resolving enum names if needed."""
        try:
            return int(value)
        except (ValueError, TypeError):
            if 'enum' not in attr_spec:
                raise
        return self._encode_enum(attr_spec, value)

    def _add_attr(self, space, name, value, search_attrs):
        """Encode attribute `name` of attribute set `space` into bytes.

        `search_attrs` is the SpaceAttrs scope chain used to resolve
        sub-message selectors. Returns the padded attribute(s); an empty
        bytes object for a false 'flag' attribute.
        """
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value

        if attr.is_multi and isinstance(value, list):
            attr_payload = b''
            for subvalue in value:
                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
            return attr_payload

        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            # The nested values belong to the nested attribute set, not to
            # the set of the attribute which contains them.
            sub_space = attr['nested-attributes']
            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
        elif attr["type"] == 'flag':
            if not value:
                # If value is absent or false then skip attribute creation.
                return b''
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            if isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                attr_payload = bytes.fromhex(value)
            elif isinstance(value, dict) and attr.struct_name:
                # _encode_struct() consumes its input, don't clobber the
                # dict handed to us by the caller.
                attr_payload = self._encode_struct(attr.struct_name, dict(value))
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
            scalar = self._get_scalar(attr, value)
            if attr.is_auto_scalar:
                sign = attr["type"][0]
                if sign == 's':
                    # bit_length() ignores the sign bit, check the range
                    fits_32 = -(1 << 31) <= scalar < (1 << 31)
                else:
                    fits_32 = scalar.bit_length() <= 32
                attr_type = sign + ('32' if fits_32 else '64')
            else:
                attr_type = attr["type"]
            fmt = NlAttr.get_format(attr_type, attr.byte_order)
            attr_payload = fmt.pack(scalar)
        elif attr['type'] == "bitfield32":
            scalar_value = self._get_scalar(attr, value["value"])
            scalar_selector = self._get_scalar(attr, value["selector"])
            attr_payload = struct.pack("II", scalar_value, scalar_selector)
        elif attr['type'] == 'sub-message':
            msg_format = self._resolve_selector(attr, search_attrs)
            attr_payload = b''
            if msg_format.fixed_header:
                # _encode_struct() pops the members it encodes so that only
                # attributes are left, work on a copy of the caller's dict.
                value = dict(value)
                attr_payload += self._encode_struct(msg_format.fixed_header, value)
            if msg_format.attr_set:
                if msg_format.attr_set in self.attr_sets:
                    nl_type |= Netlink.NLA_F_NESTED
                    # SpaceAttrs needs the attribute set itself, not its name
                    sub_attrs = SpaceAttrs(self.attr_sets[msg_format.attr_set],
                                           value, search_attrs)
                    for subname, subvalue in value.items():
                        attr_payload += self._add_attr(msg_format.attr_set,
                                                       subname, subvalue, sub_attrs)
                else:
                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        # The length in the header excludes the padding
        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Translate a scalar into an enum entry name, or a set of names for flags."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        """Decode a 'binary' attribute as a struct, a C array or raw bytes."""
        if attr_spec.struct_name:
            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = self._formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode_array_attr(self, attr, attr_spec):
        """Decode an 'indexed-array' attribute into a list of its entries."""
        decoded = []
        offset = 0
        while offset < len(attr.raw):
            item = NlAttr(attr.raw, offset)
            offset += item.full_len

            if attr_spec["sub-type"] == 'nest':
                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
                decoded.append({ item.type: subattrs })
            elif attr_spec["sub-type"] == 'binary':
                subattrs = item.as_bin()
                if attr_spec.display_hint:
                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
                decoded.append(subattrs)
            elif attr_spec["sub-type"] in NlAttr.type_formats:
                subattrs = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
                if attr_spec.display_hint:
                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
                decoded.append(subattrs)
            else:
                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
        return decoded

    def _decode_nest_type_value(self, attr, attr_spec):
        """Decode a 'nest-type-value' attribute.

        Each name listed in 'type-value' consumes one level of nesting and
        is assigned the attribute type found at that level; the innermost
        payload is decoded as 'nested-attributes'.
        """
        decoded = {}
        value = attr
        for name in attr_spec['type-value']:
            value = NlAttr(value.raw, 0)
            decoded[name] = value.type
        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
        decoded.update(subattrs)
        return decoded

    def _decode_unknown(self, attr):
        """Decode an attribute absent from the spec: nests recursively, else raw bytes."""
        if attr.is_nest:
            return self._decode(NlAttrs(attr.raw), None)
        else:
            return attr.as_bin()

    def _rsp_add(self, rsp, name, is_multi, decoded):
        """Store `decoded` under `name` in `rsp`.

        With `is_multi` set to None the multiplicity is unknown: a repeated
        name turns the existing entry into a list.
        """
        if is_multi is None:
            if name in rsp and not isinstance(rsp[name], list):
                rsp[name] = [rsp[name]]
                is_multi = True
            else:
                is_multi = False

        if not is_multi:
            rsp[name] = decoded
        elif name in rsp:
            rsp[name].append(decoded)
        else:
            rsp[name] = [decoded]

    def _resolve_selector(self, attr_spec, search_attrs):
        """Return the sub-message format chosen by the attribute's selector."""
        sub_msg = attr_spec.sub_message
        if sub_msg not in self.sub_msgs:
            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
        sub_msg_spec = self.sub_msgs[sub_msg]

        selector = attr_spec.selector
        value = search_attrs.lookup(selector)
        if value not in sub_msg_spec.formats:
            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")

        spec = sub_msg_spec.formats[value]
        return spec

    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
        """Decode a 'sub-message': optional fixed header followed by attributes."""
        msg_format = self._resolve_selector(attr_spec, search_attrs)
        decoded = {}
        offset = 0
        if msg_format.fixed_header:
            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
            offset = self._struct_size(msg_format.fixed_header)
        if msg_format.attr_set:
            if msg_format.attr_set in self.attr_sets:
                # Pass the enclosing scopes on, like the 'nest' case does, so
                # that selectors of nested sub-messages can be resolved.
                subdict = self._decode(NlAttrs(attr.raw, offset),
                                       msg_format.attr_set, search_attrs)
                decoded.update(subdict)
            else:
                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
        return decoded

    def _decode(self, attrs, space, outer_attrs = None):
        """Decode `attrs` according to attribute set `space` into a dict.

        `space` may be None (used for unknown nests), in which case every
        attribute is treated as unknown. Unknown attributes raise unless
        `process_unknown` was requested.
        """
        rsp = dict()
        attr_space = None
        search_attrs = None
        if space:
            attr_space = self.attr_sets[space]
            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)

        for attr in attrs:
            if attr_space is None or attr.type not in attr_space.attrs_by_val:
                if not self.process_unknown:
                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
                attr_name = f"UnknownAttr({attr.type})"
                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
                continue
            attr_spec = attr_space.attrs_by_val[attr.type]

            if attr_spec["type"] == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
                decoded = subdict
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                decoded = True
            elif attr_spec.is_auto_scalar:
                decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
            elif attr_spec["type"] in NlAttr.type_formats:
                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
                if 'enum' in attr_spec:
                    decoded = self._decode_enum(decoded, attr_spec)
                elif attr_spec.display_hint:
                    decoded = self._formatted_string(decoded, attr_spec.display_hint)
            elif attr_spec["type"] == 'indexed-array':
                decoded = self._decode_array_attr(attr, attr_spec)
            elif attr_spec["type"] == 'bitfield32':
                value, selector = struct.unpack("II", attr.raw)
                if 'enum' in attr_spec:
                    value = self._decode_enum(value, attr_spec)
                    selector = self._decode_enum(selector, attr_spec)
                decoded = {"value": value, "selector": selector}
            elif attr_spec["type"] == 'sub-message':
                decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
            elif attr_spec["type"] == 'nest-type-value':
                decoded = self._decode_nest_type_value(attr, attr_spec)
            else:
                if not self.process_unknown:
                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
                decoded = self._decode_unknown(attr)

            self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Find the dotted path of the attribute at byte offset `target`.

        `offset` is the offset of the first attribute of `attrs` within the
        request. Returns None if no attribute starts at `target`.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Skip the header of the nest
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, op, extack):
        """Replace 'bad-attr-offs' in `extack` with the attribute path, if found."""
        if 'bad-attr-offs' not in extack:
            return

        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
                                        extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def _struct_size(self, name):
        """Return the size in bytes of struct `name`, 0 if `name` is falsy."""
        if not name:
            return 0

        size = 0
        for m in self.consts[name].members:
            if m.type in ['pad', 'binary']:
                if m.struct:
                    size += self._struct_size(m.struct)
                else:
                    size += m.len
            else:
                size += NlAttr.get_format(m.type, m.byte_order).size
        return size

    def _decode_struct(self, data, name):
        """Decode struct `name` from the start of `data` into a dict."""
        members = self.consts[name].members
        attrs = dict()
        offset = 0
        for m in members:
            value = None
            if m.type == 'pad':
                offset += m.len
            elif m.type == 'binary':
                if m.struct:
                    struct_len = self._struct_size(m.struct)
                    value = self._decode_struct(data[offset : offset + struct_len],
                                                m.struct)
                    offset += struct_len
                else:
                    value = data[offset : offset + m.len]
                    offset += m.len
            else:
                fmt = NlAttr.get_format(m.type, m.byte_order)
                [ value ] = fmt.unpack_from(data, offset)
                offset += fmt.size
            if value is not None:
                if m.enum:
                    value = self._decode_enum(value, m)
                elif m.display_hint:
                    value = self._formatted_string(value, m.display_hint)
                attrs[m.name] = value
        return attrs

    def _encode_struct(self, name, vals):
        """Encode struct `name` from `vals`, missing members become zeros.

        NB: the members which get encoded are removed from `vals`, callers
        depend on this to be left with just the attributes. Callers which
        received `vals` from the user must pass in a copy.
        """
        members = self.consts[name].members
        attr_payload = b''
        for m in members:
            value = vals.pop(m.name) if m.name in vals else None
            if m.type == 'pad':
                attr_payload += bytearray(m.len)
            elif m.type == 'binary':
                if m.struct:
                    # Copy, the nested dict still belongs to the caller
                    value = dict() if value is None else dict(value)
                    attr_payload += self._encode_struct(m.struct, value)
                else:
                    if value is None:
                        attr_payload += bytearray(m.len)
                    else:
                        attr_payload += bytes.fromhex(value)
            else:
                if value is None:
                    value = 0
                fmt = NlAttr.get_format(m.type, m.byte_order)
                attr_payload += fmt.pack(value)
        return attr_payload

    def _formatted_string(self, raw, display_hint):
        """Format `raw` per `display_hint`; unknown hints return `raw` as is."""
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            if isinstance(raw, int):
                formatted = hex(raw)
            else:
                formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def handle_ntf(self, decoded):
        """Decode a notification and append it to `async_msg_queue`."""
        msg = dict()
        if self.include_raw:
            msg['raw'] = decoded
        op = self.rsp_by_value[decoded.cmd()]
        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
        if op.fixed_header:
            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))

        msg['name'] = op['name']
        msg['msg'] = attrs
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Drain the socket without blocking, queueing the notifications."""
        while True:
            try:
                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                decoded = self.nlproto.decode(self, nl_msg, None)
                if decoded.cmd() not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", decoded)
                    continue

                self.handle_ntf(decoded)

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _encode_message(self, op, vals, flags, req_seq):
        """Build the complete request for `op` with attributes from `vals`."""
        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag

        # _encode_struct() removes the fixed header members from the dict
        # so that only attributes remain, don't modify the caller's dict.
        vals = dict(vals)

        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            msg += self._encode_struct(op.fixed_header, vals)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
        msg = _genl_msg_finalize(msg)
        return msg

    def _ops(self, ops):
        """Send the requests in `ops` and collect one response per request.

        `ops` is an iterable of (method, vals, flags) tuples. The response
        of a dump is always a list; for other requests it is None, a single
        dict, or a list if multiple replies were received. Notifications
        received while waiting are queued via handle_ntf().
        Raises NlError on the first error message received.
        """
        reqs_by_seq = {}
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            # flags may be None, normalize for the membership test below
            reqs_by_seq[req_seq] = (op, msg, flags or [])
            payload += msg
            req_seq += 1

        if not reqs_by_seq:
            # Nothing to send, there would be no reply to wait for
            return []

        self.sock.send(payload, 0)

        # Op whose attribute set is used to annotate extacks. It tracks the
        # op of the most recently matched message; messages which match no
        # request (e.g. notifications) must not reset it.
        extack_op = op

        done = False
        rsp = []
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            nms = NlMsgs(reply, attr_space=extack_op.attr_set)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    extack_op = op
                    if nl_msg.extack:
                        self._decode_extack(req_msg, op, nl_msg.extack)
                else:
                    op = None
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    if op is None:
                        # ACK / DONE which doesn't match any of our requests
                        print('Unexpected message: ' + repr(nl_msg))
                        break

                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    break

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp

    def _op(self, method, vals, flags=None, dump=False):
        """Execute a single operation and return its response."""
        # Copy, so that adding the dump flag doesn't modify the caller's list
        req_flags = list(flags or [])
        if dump:
            req_flags.append(Netlink.NLM_F_DUMP)

        ops = [(method, vals, req_flags)]
        return self._ops(ops)[0]

    def do(self, method, vals, flags=None):
        """Execute the 'do' form of operation `method`."""
        return self._op(method, vals, flags)

    def dump(self, method, vals):
        """Execute the 'dump' form of operation `method`, returns a list."""
        return self._op(method, vals, dump=True)

    def do_multi(self, ops):
        """Execute multiple (method, vals, flags) operations as one batch."""
        return self._ops(ops)