# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
from enum import Enum
import functools
import os
import random
import socket
import struct
from struct import Struct
import sys
import yaml
import ipaddress
import uuid

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Namespace of netlink / genetlink protocol constants."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Functional Enum API numbers members from 1, so 'flag' == 1, 'u8' == 2, ...
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                 's8', 's16', 's32', 's64',
                                 'binary', 'string', 'nul-string',
                                 'nested', 'nested-array',
                                 'bitfield32', 'sint', 'uint'])


class NlError(Exception):
    """Raised when the kernel answers a request with a non-zero error code."""

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        # NlMsg.error holds a negative errno; store the positive value.
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"


class ConfigError(Exception):
    """Raised for invalid local configuration (e.g. a too-small recv size)."""
    pass


class NlAttr:
    """A single netlink attribute (struct nlattr) parsed out of a buffer."""

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    # Pre-compiled struct formats for every fixed-width scalar type, in host,
    # big-endian and little-endian byte order.
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute whose header starts at raw[offset].

        Raises:
            Exception: if the length field is smaller than the 4-byte header;
                such an attribute would never advance the parse offset.
        """
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        if self._len < 4:
            raise Exception(
                f"Malformed netlink attribute at offset {offset}: length {self._len}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        # Attributes are padded to a 4-byte boundary (NLA_ALIGN).
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for attr_type.

        byte_order "big-endian" selects big-endian, any other non-empty value
        selects little-endian, and None/empty selects host order.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" \
                else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the payload as a fixed-width scalar of attr_type."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a variable-width ('sint'/'uint') payload of 4 or 8 bytes."""
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar len payload must be 4 or 8 bytes, got {len(self.raw)}")
        # 'sint' -> 's32'/'s64', 'uint' -> 'u32'/'u64' depending on payload size.
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII, dropping a trailing NUL if present.

        Not every string attribute is NUL-terminated, so only strip the last
        byte when it actually is a NUL.
        """
        raw = self.raw
        if raw.endswith(b'\x00'):
            raw = raw[:-1]
        return raw.decode('ascii')

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as a packed array of host-order scalars."""
        fmt = self.get_format(type)
        return [x[0] for x in fmt.iter_unpack(self.raw)]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """Sequence of consecutive netlink attributes parsed from msg[offset:]."""

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """A single netlink message (nlmsghdr + payload) parsed from a buffer."""

    def __init__(self, msg, offset, attr_space=None):
        """Parse the message starting at msg[offset].

        For NLMSG_ERROR / NLMSG_DONE the (negative errno) error code is stored
        in self.error and self.done is set.  Extended-ack TLVs, when flagged
        by NLM_F_ACK_TLVS, are decoded into the self.extack dict.  If
        attr_space is given, a top-level 'miss-type' is translated from its
        number to the attribute name in that space.
        """
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = self._error_extack_offset()
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = self._decode_extack_attrs(self.raw[extack_off:])
            if attr_space:
                self._resolve_miss_type(attr_space)

    def _error_extack_offset(self):
        """Return where extack TLVs start in an NLMSG_ERROR payload.

        The payload is a 4-byte error code followed by the offending request:
        only its 16-byte header when NLM_F_CAPPED is set, otherwise the whole
        request, padded to a 4-byte boundary.
        """
        if self.nl_flags & Netlink.NLM_F_CAPPED or len(self.raw) < 8:
            return 20
        orig_len = struct.unpack_from("I", self.raw, 4)[0]
        return (4 + orig_len + 3) & ~3

    def _decode_extack_attrs(self, raw):
        """Decode the NLMSGERR_ATTR_* TLVs in raw into a dict."""
        extack = dict()
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NLMSGERR_ATTR_MSG:
                extack['msg'] = attr.as_strz()
            elif attr.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                extack['miss-type'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                extack['miss-nest'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NLMSGERR_ATTR_OFFS:
                extack['bad-attr-offs'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NLMSGERR_ATTR_POLICY:
                extack['policy'] = self._decode_policy(attr.raw)
            else:
                extack.setdefault('unknown', []).append(attr)
        return extack

    def _resolve_miss_type(self, attr_space):
        """Replace a numeric top-level 'miss-type' with the attribute name."""
        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
            miss_type = self.extack['miss-type']
            if miss_type in attr_space.attrs_by_val:
                spec = attr_space.attrs_by_val[miss_type]
                self.extack['miss-type'] = spec['name']
                if 'doc' in spec:
                    self.extack['miss-type-doc'] = spec['doc']

    def _decode_policy(self, raw):
        """Decode a nested NLMSGERR_ATTR_POLICY into a dict of constraints."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                attr_type = attr.as_scalar('u32')
                try:
                    policy['type'] = Netlink.AttrType(attr_type).name
                except ValueError:
                    # Type not in our table (e.g. newer kernel): keep the number.
                    policy['type'] = attr_type
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def cmd(self):
        """Return nlmsg_type (the command for raw netlink families)."""
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages contained in one recv() buffer."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            if msg.nl_len < 16:
                # A length shorter than the header would never advance offset.
                raise Exception(
                    f"Malformed netlink message at offset {offset}: length {msg.nl_len}")
            # Messages are padded to a 4-byte boundary (NLMSG_ALIGN).
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Lazily populated by _genl_load_families(): family name -> family info dict.
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build nlmsghdr (minus the length) + genlmsghdr for a request."""
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the 4-byte total length field to a message body."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_parse_mcast_groups(attr):
    """Map multicast group name -> id from a CTRL_ATTR_MCAST_GROUPS nest."""
    groups = dict()
    for entry in NlAttrs(attr.raw):
        mcast_name = None
        mcast_id = None
        for entry_attr in NlAttrs(entry.raw):
            if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                mcast_name = entry_attr.as_strz()
            elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                mcast_id = entry_attr.as_scalar('u32')
        if mcast_name and mcast_id is not None:
            groups[mcast_name] = mcast_id
    return groups


def _genl_parse_family(gm):
    """Extract id / name / maxattr / mcast groups from one CTRL_CMD_GETFAMILY reply."""
    fam = dict()
    for attr in NlAttrs(gm.raw):
        if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
            fam['id'] = attr.as_scalar('u16')
        elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
            fam['name'] = attr.as_strz()
        elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
            fam['maxattr'] = attr.as_scalar('u32')
        elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
            fam['mcast'] = _genl_parse_mcast_groups(attr)
    return fam


def _genl_load_families():
    """Dump all generic netlink families into genl_family_name_to_id.

    On a netlink error the error is printed and loading stops, leaving
    whatever families were collected so far.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                fam = _genl_parse_family(GenlMsg(nl_msg))
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """Generic netlink view of an NlMsg: genlmsghdr split off the payload."""

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the genetlink command."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl) + '\n'
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is only attached by NetlinkProtocol.decode().
        for a in getattr(self, 'raw_attrs', []):
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Message framing for a raw (non-generic) netlink protocol."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Build an nlmsghdr without its leading length field."""
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        """Build a request header; for raw netlink the command is nlmsg_type."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap nl_msg for this protocol and attach its attributes as raw_attrs.

        The attributes start after op's fixed header.  If op is None it is
        looked up in ynl.rsp_by_value by the decoded message's command.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the multicast group id declared in the spec."""
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value

    def msghdr_size(self):
        """Size in bytes of the protocol headers preceding fixed header/attrs."""
        return 16


class GenlProtocol(NetlinkProtocol):
    """Message framing for a generic netlink family resolved at runtime."""

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        # KeyError here means the kernel does not know the family.
        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']

    def message(self, flags, command, version, seq=None):
        """Build nlmsghdr (addressed to the family id) + genlmsghdr."""
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the multicast group id as reported by the kernel."""
        if mcast_name not in self.genl_family['mcast']:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return self.genl_family['mcast'][mcast_name]

    def msghdr_size(self):
        return super().msghdr_size() + 4


class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes used to resolve selectors."""

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        outer_scopes = outer.scopes if outer else []
        inner_scope = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner_scope] + outer_scopes

    def lookup(self, name):
        """Return the value of attribute name from the innermost scope defining it.

        Raises:
            Exception: if the defining scope has no value for it, or no scope
                defines it at all.
        """
        for scope in self.scopes:
            if name in scope.spec:
                if name in scope.values:
                    return scope.values[name]
                spec_name = scope.spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """Netlink family driven by a YNL spec.

    Every op in the spec is exposed as a bound method named op.ident_name,
    which forwards to self._op(op_name, ...).
    """

    def __init__(self, def_path, schema=None, process_unknown=False,
                 recv_size=0):
        super().__init__(def_path, schema)

        self.include_raw = False
        self.process_unknown = process_unknown

        try:
            if self.proto == "netlink-raw":
                self.nlproto = NetlinkProtocol(self.yaml['name'],
                                               self.yaml['protonum'])
            else:
                self.nlproto = GenlProtocol(self.yaml['name'])
        except KeyError:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

        self._recv_dbg = False
        # Note that netlink will use conservative (min) message size for
        # the first dump recv() on the socket, our setting will only matter
        # from the second recv() on.
        self._recv_size = recv_size if recv_size else 131072
        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
        # for a message, so smaller receive sizes will lead to truncation.
        # Note that the min size for other families may be larger than 4k!
        if self._recv_size < 4000:
            raise ConfigError(
                f"recv_size {self._recv_size} is too small, must be at least 4000")

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

    def ntf_subscribe(self, mcast_name):
        """Join the multicast group mcast_name on this family's socket."""
        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_id)

    def set_recv_dbg(self, enabled):
        """Enable/disable dumping every received buffer to stderr."""
        self._recv_dbg = enabled

    def _recv_dbg_print(self, reply, nl_msgs):
        if not self._recv_dbg:
            return
        print("Recv: read", len(reply), "bytes,",
              len(nl_msgs.msgs), "messages", file=sys.stderr)
        for nl_msg in nl_msgs:
            print("  ", nl_msg, file=sys.stderr)

    def _encode_enum(self, attr_spec, value):
        """Convert an enum name (or, for flags, name/list of names) to an int."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            scalar = 0
            if isinstance(value, str):
                value = [value]
            for single_value in value:
                scalar += enum.entries[single_value].user_value(as_flags = True)
            return scalar
        else:
            return enum.entries[value].user_value()

    def _get_scalar(self, attr_spec, value):
        """Return value as an int, falling back to enum-name lookup."""
        try:
            return int(value)
        except (ValueError, TypeError) as e:
            if 'enum' not in attr_spec:
                raise e
            return self._encode_enum(attr_spec, value)

    def _add_attr(self, space, name, value, search_attrs):
        """Encode attribute name=value from attribute-set space as nlattr bytes.

        search_attrs is the SpaceAttrs chain used to resolve sub-message
        selectors.  Multi-attrs given a list are encoded once per element.
        """
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value

        if attr.is_multi and isinstance(value, list):
            attr_payload = b''
            for subvalue in value:
                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
            return attr_payload

        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            sub_space = attr['nested-attributes']
            # The nest's values belong to the nested attribute-set, so that is
            # the spec selector lookups must consult for this scope.
            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
        elif attr["type"] == 'flag':
            if not value:
                # If value is absent or false then skip attribute creation.
                return b''
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            if isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                attr_payload = bytes.fromhex(value)
            elif isinstance(value, dict) and attr.struct_name:
                attr_payload = self._encode_struct(attr.struct_name, value)
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
            scalar = self._get_scalar(attr, value)
            if attr.is_auto_scalar:
                if attr["type"][0] == 's':
                    # Signed values must lie in [-2**31, 2**31) to fit s32.
                    fits_32 = -(1 << 31) <= scalar < (1 << 31)
                else:
                    fits_32 = scalar.bit_length() <= 32
                attr_type = attr["type"][0] + ('32' if fits_32 else '64')
            else:
                attr_type = attr["type"]
            fmt = NlAttr.get_format(attr_type, attr.byte_order)
            attr_payload = fmt.pack(scalar)
        elif attr['type'] == "bitfield32":
            scalar_value = self._get_scalar(attr, value["value"])
            scalar_selector = self._get_scalar(attr, value["selector"])
            attr_payload = struct.pack("II", scalar_value, scalar_selector)
        elif attr['type'] == 'sub-message':
            msg_format = self._resolve_selector(attr, search_attrs)
            attr_payload = b''
            if msg_format.fixed_header:
                # _encode_struct() consumes fixed-header keys; work on a copy
                # so the caller's dict is left intact.
                value = dict(value)
                attr_payload += self._encode_struct(msg_format.fixed_header, value)
            if msg_format.attr_set:
                if msg_format.attr_set in self.attr_sets:
                    nl_type |= Netlink.NLA_F_NESTED
                    sub_attrs = SpaceAttrs(self.attr_sets[msg_format.attr_set],
                                           value, search_attrs)
                    for subname, subvalue in value.items():
                        attr_payload += self._add_attr(msg_format.attr_set,
                                                       subname, subvalue, sub_attrs)
                else:
                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Convert an int to its enum name, or to a set of names for flags."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attr as struct, C array, or raw bytes (+display hint)."""
        if attr_spec.struct_name:
            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = self._formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode_array_attr(self, attr, attr_spec):
        """Decode an indexed-array attribute into a list of its elements."""
        decoded = []
        offset = 0
        while offset < len(attr.raw):
            item = NlAttr(attr.raw, offset)
            offset += item.full_len

            if attr_spec["sub-type"] == 'nest':
                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
                decoded.append({ item.type: subattrs })
            elif attr_spec["sub-type"] == 'binary':
                subattrs = item.as_bin()
                if attr_spec.display_hint:
                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
                decoded.append(subattrs)
            elif attr_spec["sub-type"] in NlAttr.type_formats:
                subattrs = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
                if attr_spec.display_hint:
                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
                decoded.append(subattrs)
            else:
                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
        return decoded

    def _decode_nest_type_value(self, attr, attr_spec):
        """Decode a nest whose nested attr *types* carry the named values."""
        decoded = {}
        value = attr
        for name in attr_spec['type-value']:
            value = NlAttr(value.raw, 0)
            decoded[name] = value.type
        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
        decoded.update(subattrs)
        return decoded

    def _decode_unknown(self, attr):
        """Decode an attr with no spec: recurse into nests, else raw bytes."""
        if attr.is_nest:
            return self._decode(NlAttrs(attr.raw), None)
        else:
            return attr.as_bin()

    def _rsp_add(self, rsp, name, is_multi, decoded):
        """Store decoded under name in rsp.

        is_multi None means "unknown": the first value is stored as-is and a
        repeated name promotes the entry to a list that later values extend.
        """
        if is_multi is None:
            if name in rsp:
                if type(rsp[name]) is not list:
                    rsp[name] = [rsp[name]]
                is_multi = True
            else:
                is_multi = False

        if not is_multi:
            rsp[name] = decoded
        elif name in rsp:
            rsp[name].append(decoded)
        else:
            rsp[name] = [decoded]

    def _resolve_selector(self, attr_spec, search_attrs):
        """Pick the sub-message format selected by the attr's selector value."""
        sub_msg = attr_spec.sub_message
        if sub_msg not in self.sub_msgs:
            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
        sub_msg_spec = self.sub_msgs[sub_msg]

        selector = attr_spec.selector
        value = search_attrs.lookup(selector)
        if value not in sub_msg_spec.formats:
            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")

        spec = sub_msg_spec.formats[value]
        return spec

    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
        """Decode a sub-message attr: optional fixed header then attributes."""
        msg_format = self._resolve_selector(attr_spec, search_attrs)
        decoded = {}
        offset = 0
        if msg_format.fixed_header:
            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
            offset = self._struct_size(msg_format.fixed_header)
        if msg_format.attr_set:
            if msg_format.attr_set in self.attr_sets:
                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
                decoded.update(subdict)
            else:
                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
        return decoded

    def _decode(self, attrs, space, outer_attrs = None):
        """Decode attrs belonging to attribute-set space into a dict.

        space may be None, in which case every attr is treated as unknown.
        Unknown attrs raise unless process_unknown is set, in which case they
        are stored as 'UnknownAttr(<type>)'.
        """
        rsp = dict()
        attr_space = None
        search_attrs = None
        if space:
            attr_space = self.attr_sets[space]
            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)

        for attr in attrs:
            attr_spec = None
            if attr_space is not None:
                try:
                    attr_spec = attr_space.attrs_by_val[attr.type]
                except KeyError:
                    pass
            if attr_spec is None:
                if not self.process_unknown:
                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
                attr_name = f"UnknownAttr({attr.type})"
                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
                continue

            if attr_spec["type"] == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
                decoded = subdict
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                decoded = True
            elif attr_spec.is_auto_scalar or attr_spec["type"] in NlAttr.type_formats:
                if attr_spec.is_auto_scalar:
                    decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
                else:
                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
                # Enums/display hints apply to auto-scalars too, mirroring
                # _add_attr() which accepts enum names for them.
                if 'enum' in attr_spec:
                    decoded = self._decode_enum(decoded, attr_spec)
                elif attr_spec.display_hint:
                    decoded = self._formatted_string(decoded, attr_spec.display_hint)
            elif attr_spec["type"] == 'indexed-array':
                decoded = self._decode_array_attr(attr, attr_spec)
            elif attr_spec["type"] == 'bitfield32':
                value, selector = struct.unpack("II", attr.raw)
                if 'enum' in attr_spec:
                    value = self._decode_enum(value, attr_spec)
                    selector = self._decode_enum(selector, attr_spec)
                decoded = {"value": value, "selector": selector}
            elif attr_spec["type"] == 'sub-message':
                decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
            elif attr_spec["type"] == 'nest-type-value':
                decoded = self._decode_nest_type_value(attr, attr_spec)
            else:
                if not self.process_unknown:
                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
                decoded = self._decode_unknown(attr)

            self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return the dotted attribute path of the attr at byte offset target, or None."""
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Skip the nest's own 4-byte header before descending.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, op, extack):
        """Replace extack 'bad-attr-offs' with a readable 'bad-attr' path."""
        if 'bad-attr-offs' not in extack:
            return

        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
                                        extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def _struct_size(self, name):
        """Return the encoded size of struct name (0 when name is falsy)."""
        if name:
            members = self.consts[name].members
            size = 0
            for m in members:
                if m.type in ['pad', 'binary']:
                    if m.struct:
                        size += self._struct_size(m.struct)
                    else:
                        size += m.len
                else:
                    fmt = NlAttr.get_format(m.type, m.byte_order)
                    size += fmt.size
            return size
        else:
            return 0

    def _decode_struct(self, data, name):
        """Decode struct name from data into a dict keyed by member name."""
        members = self.consts[name].members
        attrs = dict()
        offset = 0
        for m in members:
            value = None
            if m.type == 'pad':
                offset += m.len
            elif m.type == 'binary':
                if m.struct:
                    size = self._struct_size(m.struct)
                    value = self._decode_struct(data[offset : offset + size],
                                                m.struct)
                    offset += size
                else:
                    value = data[offset : offset + m.len]
                    offset += m.len
            else:
                fmt = NlAttr.get_format(m.type, m.byte_order)
                [ value ] = fmt.unpack_from(data, offset)
                offset += fmt.size
            if value is not None:
                if m.enum:
                    value = self._decode_enum(value, m)
                elif m.display_hint:
                    value = self._formatted_string(value, m.display_hint)
                attrs[m.name] = value
        return attrs

    def _encode_struct(self, name, vals):
        """Encode struct name from vals; missing members are zero-filled.

        Members that are encoded are popped from vals so the caller can
        encode whatever remains as attributes.
        """
        members = self.consts[name].members
        attr_payload = b''
        for m in members:
            value = vals.pop(m.name) if m.name in vals else None
            if m.type == 'pad':
                attr_payload += bytearray(m.len)
            elif m.type == 'binary':
                if m.struct:
                    # Copy so the caller's nested dict is not consumed.
                    value = dict(value) if value is not None else dict()
                    attr_payload += self._encode_struct(m.struct, value)
                else:
                    if value is None:
                        attr_payload += bytearray(m.len)
                    elif isinstance(value, (bytes, bytearray)):
                        attr_payload += value
                    else:
                        attr_payload += bytes.fromhex(value)
            else:
                if value is None:
                    value = 0
                fmt = NlAttr.get_format(m.type, m.byte_order)
                attr_payload += fmt.pack(value)
        return attr_payload

    def _formatted_string(self, raw, display_hint):
        """Render raw according to a spec display hint (mac/hex/ipv4/ipv6/uuid)."""
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            if isinstance(raw, int):
                formatted = hex(raw)
            else:
                formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def handle_ntf(self, decoded):
        """Decode a notification and append it to async_msg_queue."""
        msg = dict()
        if self.include_raw:
            msg['raw'] = decoded
        op = self.rsp_by_value[decoded.cmd()]
        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
        if op.fixed_header:
            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))

        msg['name'] = op['name']
        msg['msg'] = attrs
        self.async_msg_queue.append(msg)

    def _op_for_msg(self, nl_msg):
        """Return the spec op whose response command matches nl_msg, or None.

        The command comes from the protocol decode (the genetlink command for
        genetlink families), not NlMsg.cmd(), which is always nlmsg_type.
        """
        cmd = self.nlproto._decode(nl_msg).cmd()
        try:
            return self.rsp_by_value[cmd]
        except KeyError:
            return None

    def check_ntf(self):
        """Drain pending notifications from the socket without blocking."""
        while True:
            try:
                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                op = self._op_for_msg(nl_msg)
                if op is None:
                    print("Unexpected msg id done while checking for ntf", nl_msg)
                    continue
                decoded = self.nlproto.decode(self, nl_msg, op)
                if decoded.cmd() not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", decoded)
                    continue

                self.handle_ntf(decoded)

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _encode_message(self, op, vals, flags, req_seq):
        """Encode a complete request message for op with sequence req_seq."""
        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag

        # _encode_struct() pops fixed-header members from the dict it gets;
        # work on a copy so the caller's vals can be reused.
        vals = dict(vals) if vals else {}
        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            msg += self._encode_struct(op.fixed_header, vals)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
        msg = _genl_msg_finalize(msg)
        return msg

    def _ops(self, ops):
        """Send (method, vals, flags) requests in one batch and collect replies.

        Returns one entry per request: a list for dumps, otherwise None, the
        single reply dict, or a list of reply dicts.

        Raises:
            NlError: when any reply carries a non-zero error code.
        """
        reqs_by_seq = {}
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            reqs_by_seq[req_seq] = (op, msg, flags)
            payload += msg
            req_seq += 1

        if not reqs_by_seq:
            # Nothing to send; waiting on recv() would block forever.
            return []

        self.sock.send(payload, 0)

        done = False
        rsp = []
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            # Attribute space used to name missing attrs in extack; taken from
            # the oldest request still awaiting completion.
            pending_op = next(iter(reqs_by_seq.values()))[0]
            nms = NlMsgs(reply, attr_space=pending_op.attr_set)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    if nl_msg.extack:
                        self._decode_extack(req_msg, op, nl_msg.extack)
                else:
                    op = None
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if op is None:
                        # ACK/DONE for a request we are not waiting on.
                        print('Unexpected message: ' + repr(nl_msg))
                        continue
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    # Keep going: later messages in this buffer may answer
                    # other pending requests or be notifications.
                    continue

                if op is None:
                    op = self._op_for_msg(nl_msg)
                    if op is None:
                        print('Unexpected message: ' + repr(nl_msg))
                        continue

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                    else:
                        print('Unexpected message: ' + repr(decoded))
                    continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp

    def _op(self, method, vals, flags=None, dump=False):
        """Run a single op; dump=True adds NLM_F_DUMP."""
        # Copy so appending NLM_F_DUMP never mutates the caller's list.
        req_flags = list(flags or [])
        if dump:
            req_flags.append(Netlink.NLM_F_DUMP)

        ops = [(method, vals, req_flags)]
        return self._ops(ops)[0]

    def do(self, method, vals, flags=None):
        """Run a 'do' request for method."""
        return self._op(method, vals, flags)

    def dump(self, method, vals):
        """Run a dump request for method; returns a list of replies."""
        return self._op(method, vals, dump=True)

    def do_multi(self, ops):
        """Run several (method, vals, flags) requests in one batch."""
        return self._ops(ops)