# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
from enum import Enum
import functools
import os
import random
import socket
import struct
from struct import Struct
import sys
import yaml
import ipaddress
import uuid

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Namespace for the netlink / genetlink numeric constants used by this module."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Functional Enum numbering starts at 1, so 'flag' == 1, 'u8' == 2, ...
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                 's8', 's16', 's32', 's64',
                                 'binary', 'string', 'nul-string',
                                 'nested', 'nested-array',
                                 'bitfield32', 'sint', 'uint'])


class NlError(Exception):
    """Raised when the kernel answers a request with a non-zero error code.

    Attributes:
        nl_msg: the NlMsg carrying the error (and extack, if any).
        error: the positive errno value.
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        # NlMsg.error holds the kernel's negative errno; flip it for strerror().
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"


class ConfigError(Exception):
    """Raised for invalid library configuration (e.g. a too small receive size)."""
    pass


class NlAttr:
    """One netlink attribute (TLV) parsed out of a byte buffer.

    Attributes:
        type: attribute type with the NESTED / NET_BYTEORDER flag bits masked off.
        is_nest: non-zero when NLA_F_NESTED was set.
        payload_len: nla_len as sent (4-byte header included).
        full_len: nla_len rounded up to 4, i.e. the distance to the next attribute.
        raw: payload bytes (header excluded).
    """

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        if self._len < 4:
            # nla_len counts its own 4-byte header, so anything smaller is
            # malformed; a 0 length would also make attribute walks spin forever.
            raise Exception(f"Malformed netlink attribute: length {self._len} at offset {offset}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the precompiled Struct for scalar type ``attr_type``.

        byte_order "big-endian" selects big endian, any other truthy value
        little endian, and a falsy value native byte order.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" \
                else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the payload as a single fixed-width scalar of ``attr_type``."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a variable-width 'sint'/'uint' whose width is the payload length.

        Raises Exception unless the payload is 4 or 8 bytes long.
        """
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar len payload be 4 or 8 bytes, got {len(self.raw)}")
        # 's'/'u' prefix + bit width, e.g. 'sint' with 8 bytes -> 's64'.
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII, dropping the trailing NUL byte."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of native-endian ``type`` scalars."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """Sequence of consecutive attributes parsed from ``msg`` starting at ``offset``."""

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """One netlink message parsed out of ``msg`` at ``offset``.

    For NLMSG_ERROR / NLMSG_DONE it also extracts the error code and, when
    NLM_F_ACK_TLVS is set, the extended ACK attributes into ``self.extack``.
    When ``attr_space`` is given, a top-level missing-attribute type in the
    extack is translated to its name (and doc, if any).
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)
        if self.nl_len < 16:
            # nlmsg_len includes the 16-byte header; smaller is malformed and
            # a 0 length would make NlMsgs loop forever.
            raise Exception(f"Malformed netlink message: length {self.nl_len} at offset {offset}")

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # Presumably errno (4 bytes) + echoed request header (16 bytes);
            # relies on NETLINK_CAP_ACK keeping the request payload out.
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        self.extack['miss-type'] = spec['name']
                        if 'doc' in spec:
                            self.extack['miss-type-doc'] = spec['doc']

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY nest into a dict of policy properties."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                policy_type = attr.as_scalar('u32')
                try:
                    policy['type'] = Netlink.AttrType(policy_type).name
                except ValueError:
                    # Type number not in AttrType (newer kernel): keep the raw value.
                    policy['type'] = policy_type
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def cmd(self):
        """Return the nlmsg_type, which is the command for raw netlink families."""
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages contained in one recv() buffer."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            # Messages start on 4-byte boundaries (NLMSG_ALIGN / NLMSG_NEXT).
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Cache of genetlink families: name -> dict of parsed nlctrl attributes.
# Filled lazily by _genl_load_families().
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build a genetlink message header without the leading nlmsg_len field."""
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the 4-byte nlmsg_len (which counts itself) to ``msg``."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump all genetlink families via nlctrl into genl_family_name_to_id.

    Each entry is a dict with 'id' and 'name', plus 'maxattr' and 'mcast'
    (multicast group name -> id) when the kernel reports them. On a netlink
    error the error is printed and loading stops early.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in NlAttrs(gm.raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """Genetlink view of an NlMsg: splits off the 4-byte genl header.

    ``raw`` is the payload after the genl header. ``raw_attrs`` is attached
    later by NetlinkProtocol.decode() and may be absent.
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the genetlink command number."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\n\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # Messages not passed through NetlinkProtocol.decode() (e.g. in
        # _genl_load_families) have no raw_attrs; don't crash on them.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Message framing for raw (non-generic) netlink families."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Build an nlmsghdr without its leading length field."""
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        """Build a request header; for raw netlink the command is the nlmsg_type."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        """Wrap ``nl_msg`` in the protocol's message type (identity here)."""
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap ``nl_msg`` and attach ``raw_attrs`` located after op's fixed header."""
        msg = self._decode(nl_msg)
        fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the multicast group id declared in the spec for ``mcast_name``."""
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value

    def msghdr_size(self):
        """Size in bytes of the headers preceding the payload (nlmsghdr only)."""
        return 16


class GenlProtocol(NetlinkProtocol):
    """Message framing for generic netlink families (adds the genl header)."""

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        # KeyError here means the kernel does not know the family.
        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']

    def message(self, flags, command, version, seq=None):
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the kernel-assigned id of multicast group ``mcast_name``."""
        if mcast_name not in self.genl_family['mcast']:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return self.genl_family['mcast'][mcast_name]

    def msghdr_size(self):
        # nlmsghdr plus the 4-byte genl header.
        return super().msghdr_size() + 4


class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes used to resolve selectors.

    lookup() walks from the innermost scope outwards.
    """

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        outer_scopes = outer.scopes if outer else []
        inner_scope = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner_scope] + outer_scopes

    def lookup(self, name):
        """Return the value of attribute ``name`` from the innermost set defining it.

        Raises Exception if the defining set has no value for it, or no set
        in the chain defines it.
        """
        for scope in self.scopes:
            if name in scope.spec:
                if name in scope.values:
                    return scope.values[name]
                spec_name = scope.spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """Netlink family driven by a YNL YAML spec.

    Opens a netlink socket for the family and, for every operation in the
    spec, binds a method named after the op's ``ident_name`` which encodes
    the request, sends it and returns the decoded reply (see _op()).
    """

    def __init__(self, def_path, schema=None, process_unknown=False,
                 recv_size=0):
        """
        def_path: path of the YAML spec, passed to SpecFamily.
        schema: schema passed to SpecFamily.
        process_unknown: decode attributes missing from the spec as
            "UnknownAttr(N)" entries instead of raising.
        recv_size: recv() buffer size; 0 selects 131072.

        Raises Exception if the kernel does not support the family, and
        ConfigError if recv_size is below 4000 bytes.
        """
        super().__init__(def_path, schema)

        self.include_raw = False
        self.process_unknown = process_unknown

        try:
            if self.proto == "netlink-raw":
                self.nlproto = NetlinkProtocol(self.yaml['name'],
                                               self.yaml['protonum'])
            else:
                self.nlproto = GenlProtocol(self.yaml['name'])
        except KeyError:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

        self._recv_dbg = False
        # Note that netlink will use conservative (min) message size for
        # the first dump recv() on the socket, our setting will only matter
        # from the second recv() on.
        self._recv_size = recv_size if recv_size else 131072
        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
        # for a message, so smaller receive sizes will lead to truncation.
        # Note that the min size for other families may be larger than 4k!
        if self._recv_size < 4000:
            raise ConfigError(f"recv_size of {self._recv_size} bytes is too small, "
                              "netlink messages would be truncated")

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

    def ntf_subscribe(self, mcast_name):
        """Join multicast group ``mcast_name`` to receive its notifications."""
        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_id)

    def set_recv_dbg(self, enabled):
        """Enable or disable printing of every received buffer to stderr."""
        self._recv_dbg = enabled

    def _recv_dbg_print(self, reply, nl_msgs):
        if not self._recv_dbg:
            return
        print("Recv: read", len(reply), "bytes,",
              len(nl_msgs.msgs), "messages", file=sys.stderr)
        for nl_msg in nl_msgs:
            print("    ", nl_msg, file=sys.stderr)

    def _encode_enum(self, attr_spec, value):
        """Map enum entry name(s) to the integer value to put on the wire.

        For flag enums ``value`` may be one name or an iterable of names.
        """
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            scalar = 0
            if isinstance(value, str):
                value = [value]
            for single_value in value:
                scalar += enum.entries[single_value].user_value(as_flags = True)
            return scalar
        else:
            return enum.entries[value].user_value()

    def _get_scalar(self, attr_spec, value):
        """Convert ``value`` to int, falling back to enum-name lookup."""
        try:
            return int(value)
        except (ValueError, TypeError) as e:
            if 'enum' not in attr_spec:
                raise e
            return self._encode_enum(attr_spec, value)

    def _add_attr(self, space, name, value, search_attrs):
        """Encode one attribute as bytes: TLV header, payload and padding.

        space: name of the attribute set that defines ``name``.
        value: user value. Its shape depends on the attribute type: a dict
            for nests, sub-messages and bitfield32; a list for multi-attrs.
        search_attrs: SpaceAttrs scope chain used to resolve sub-message
            selectors.

        Raises Exception for unknown attributes, attribute sets or types.
        """
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value

        # A list for a multi-attr expands into one TLV per element.
        if attr.is_multi and isinstance(value, list):
            attr_payload = b''
            for subvalue in value:
                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
            return attr_payload

        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            sub_space = attr['nested-attributes']
            # The nested values belong to the nested attribute set, so that
            # set (not the enclosing one) heads the selector scope chain.
            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
        elif attr["type"] == 'flag':
            if not value:
                # If value is absent or false then skip attribute creation.
                return b''
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            if isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                attr_payload = bytes.fromhex(value)
            elif isinstance(value, dict) and attr.struct_name:
                # _encode_struct() pops members; copy to keep the caller's dict intact.
                attr_payload = self._encode_struct(attr.struct_name, dict(value))
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
            scalar = self._get_scalar(attr, value)
            if attr.is_auto_scalar:
                # Use 32 bits when the value fits the signed/unsigned 32-bit
                # range, otherwise 64. bit_length() alone is wrong for signed
                # values, e.g. 2**31 needs 64 bits as a signed integer.
                if attr["type"][0] == 's':
                    fits_32 = -(1 << 31) <= scalar < (1 << 31)
                else:
                    fits_32 = scalar < (1 << 32)
                attr_type = attr["type"][0] + ('32' if fits_32 else '64')
            else:
                attr_type = attr["type"]
            fmt = NlAttr.get_format(attr_type, attr.byte_order)
            attr_payload = fmt.pack(scalar)
        elif attr['type'] == "bitfield32":
            scalar_value = self._get_scalar(attr, value["value"])
            scalar_selector = self._get_scalar(attr, value["selector"])
            attr_payload = struct.pack("II", scalar_value, scalar_selector)
        elif attr['type'] == 'sub-message':
            msg_format = self._resolve_selector(attr, search_attrs)
            attr_payload = b''
            if msg_format.fixed_header:
                # _encode_struct() pops the header members so that only the
                # attributes remain in ``value``; work on a copy.
                value = dict(value)
                attr_payload += self._encode_struct(msg_format.fixed_header, value)
            if msg_format.attr_set:
                if msg_format.attr_set in self.attr_sets:
                    nl_type |= Netlink.NLA_F_NESTED
                    # msg_format.attr_set is a set *name*; SpaceAttrs needs the spec.
                    sub_attrs = SpaceAttrs(self.attr_sets[msg_format.attr_set],
                                           value, search_attrs)
                    for subname, subvalue in value.items():
                        attr_payload += self._add_attr(msg_format.attr_set,
                                                       subname, subvalue, sub_attrs)
                else:
                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Map a wire integer to an enum entry name, or a set of names for flags."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attribute as a struct, a C array or raw bytes."""
        if attr_spec.struct_name:
            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = self._formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode_array_attr(self, attr, attr_spec):
        """Decode an indexed-array attribute into a list of its items."""
        decoded = []
        offset = 0
        while offset < len(attr.raw):
            item = NlAttr(attr.raw, offset)
            offset += item.full_len

            if attr_spec["sub-type"] == 'nest':
                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
                decoded.append({ item.type: subattrs })
            elif attr_spec["sub-type"] == 'binary':
                subattrs = item.as_bin()
                if attr_spec.display_hint:
                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
                decoded.append(subattrs)
            elif attr_spec["sub-type"] in NlAttr.type_formats:
                subattrs = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
                if attr_spec.display_hint:
                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
                decoded.append(subattrs)
            else:
                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
        return decoded

    def _decode_nest_type_value(self, attr, attr_spec):
        """Decode a nest-type-value attribute.

        Each name in the spec's 'type-value' list takes the type of one more
        level of nesting; the innermost payload is decoded as regular attributes.
        """
        decoded = {}
        value = attr
        for name in attr_spec['type-value']:
            value = NlAttr(value.raw, 0)
            decoded[name] = value.type
        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
        decoded.update(subattrs)
        return decoded

    def _decode_unknown(self, attr):
        """Decode an attribute missing from the spec: recurse into nests, else bytes."""
        if attr.is_nest:
            return self._decode(NlAttrs(attr.raw), None)
        else:
            return attr.as_bin()

    def _rsp_add(self, rsp, name, is_multi, decoded):
        """Store ``decoded`` under ``name`` in ``rsp``.

        is_multi True appends to a list, False overwrites, and None means
        "unknown": the first value is stored plainly, and a repeat converts the
        entry into a list and keeps appending.
        """
        if is_multi is None:
            if name in rsp and not isinstance(rsp[name], list):
                rsp[name] = [rsp[name]]
            # Once the name is present its entry is a list; keep appending.
            # (The old code reset is_multi to False on the third occurrence,
            # overwriting the collected list.)
            is_multi = name in rsp

        if not is_multi:
            rsp[name] = decoded
        elif name in rsp:
            rsp[name].append(decoded)
        else:
            rsp[name] = [decoded]

    def _resolve_selector(self, attr_spec, search_attrs):
        """Return the sub-message format selected by the selector attribute's value.

        Raises Exception if the sub-message spec or the format is missing.
        """
        sub_msg = attr_spec.sub_message
        if sub_msg not in self.sub_msgs:
            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
        sub_msg_spec = self.sub_msgs[sub_msg]

        selector = attr_spec.selector
        value = search_attrs.lookup(selector)
        if value not in sub_msg_spec.formats:
            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")

        spec = sub_msg_spec.formats[value]
        return spec

    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
        """Decode a sub-message attribute: optional fixed header, then attributes."""
        msg_format = self._resolve_selector(attr_spec, search_attrs)
        decoded = {}
        offset = 0
        if msg_format.fixed_header:
            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
            offset = self._struct_size(msg_format.fixed_header)
        if msg_format.attr_set:
            if msg_format.attr_set in self.attr_sets:
                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
                decoded.update(subdict)
            else:
                # Previously referenced an undefined name and raised NameError.
                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
        return decoded

    def _decode(self, attrs, space, outer_attrs = None):
        """Decode ``attrs`` belonging to attribute set ``space`` into a dict.

        space may be None, in which case every attribute is unknown.
        outer_attrs is the enclosing SpaceAttrs used to resolve sub-message
        selectors. Unknown attributes raise unless process_unknown is set.
        """
        rsp = dict()
        attr_space = None
        search_attrs = None
        if space:
            attr_space = self.attr_sets[space]
            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)

        for attr in attrs:
            if attr_space is None or attr.type not in attr_space.attrs_by_val:
                if not self.process_unknown:
                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
                attr_name = f"UnknownAttr({attr.type})"
                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
                continue
            attr_spec = attr_space.attrs_by_val[attr.type]

            if attr_spec["type"] == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
                decoded = subdict
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                decoded = True
            elif attr_spec.is_auto_scalar:
                decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
            elif attr_spec["type"] in NlAttr.type_formats:
                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
                if 'enum' in attr_spec:
                    decoded = self._decode_enum(decoded, attr_spec)
            elif attr_spec["type"] == 'indexed-array':
                decoded = self._decode_array_attr(attr, attr_spec)
            elif attr_spec["type"] == 'bitfield32':
                value, selector = struct.unpack("II", attr.raw)
                if 'enum' in attr_spec:
                    value = self._decode_enum(value, attr_spec)
                    selector = self._decode_enum(selector, attr_spec)
                decoded = {"value": value, "selector": selector}
            elif attr_spec["type"] == 'sub-message':
                decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
            elif attr_spec["type"] == 'nest-type-value':
                decoded = self._decode_nest_type_value(attr, attr_spec)
            else:
                if not self.process_unknown:
                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
                decoded = self._decode_unknown(attr)

            self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return the dotted attribute path found at byte offset ``target``.

        ``offset`` is the position of the first attribute of ``attrs`` within
        the request. Returns None if no attribute starts at ``target``.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Step over the nest's own 4-byte header into its payload.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, op, extack):
        """Replace extack 'bad-attr-offs' with a readable 'bad-attr' path, if found."""
        if 'bad-attr-offs' not in extack:
            return

        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
                                        extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def _struct_size(self, name):
        """Return the encoded size in bytes of struct ``name`` (0 when name is falsy)."""
        if name:
            members = self.consts[name].members
            size = 0
            for m in members:
                if m.type in ['pad', 'binary']:
                    if m.struct:
                        size += self._struct_size(m.struct)
                    else:
                        size += m.len
                else:
                    fmt = NlAttr.get_format(m.type, m.byte_order)
                    size += fmt.size
            return size
        else:
            return 0

    def _decode_struct(self, data, name):
        """Decode struct ``name`` from the start of ``data`` into a dict of members."""
        members = self.consts[name].members
        attrs = dict()
        offset = 0
        for m in members:
            value = None
            if m.type == 'pad':
                offset += m.len
            elif m.type == 'binary':
                if m.struct:
                    size = self._struct_size(m.struct)
                    value = self._decode_struct(data[offset : offset + size],
                                                m.struct)
                    offset += size
                else:
                    value = data[offset : offset + m.len]
                    offset += m.len
            else:
                fmt = NlAttr.get_format(m.type, m.byte_order)
                [ value ] = fmt.unpack_from(data, offset)
                offset += fmt.size
            if value is not None:
                if m.enum:
                    value = self._decode_enum(value, m)
                elif m.display_hint:
                    value = self._formatted_string(value, m.display_hint)
                attrs[m.name] = value
        return attrs

    def _encode_struct(self, name, vals):
        """Encode struct ``name`` from ``vals``.

        Members missing from ``vals`` are zero-filled. NOTE: the members found
        are popped from ``vals``; callers rely on this so that the remaining
        keys are the attributes to encode after the struct.
        """
        members = self.consts[name].members
        attr_payload = b''
        for m in members:
            value = vals.pop(m.name) if m.name in vals else None
            if m.type == 'pad':
                attr_payload += bytearray(m.len)
            elif m.type == 'binary':
                if m.struct:
                    if value is None:
                        value = dict()
                    # Copy: the recursive call pops from its argument.
                    attr_payload += self._encode_struct(m.struct, dict(value))
                else:
                    if value is None:
                        attr_payload += bytearray(m.len)
                    else:
                        attr_payload += bytes.fromhex(value)
            else:
                if value is None:
                    value = 0
                fmt = NlAttr.get_format(m.type, m.byte_order)
                attr_payload += fmt.pack(value)
        return attr_payload

    def _formatted_string(self, raw, display_hint):
        """Render ``raw`` according to a spec display-hint (mac/hex/ipv4/ipv6/uuid)."""
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            if isinstance(raw, int):
                formatted = hex(raw)
            else:
                formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def _rsp_op_for(self, nl_msg):
        """Return the op whose response matches ``nl_msg``, or None.

        The key is the protocol-level command: the genl command for generic
        netlink (where nlmsg_type is the family id) and nlmsg_type for raw
        netlink. Error/done messages and unknown commands yield None.
        """
        if nl_msg.error or nl_msg.done:
            return None
        # _decode() only splits off the protocol header, enough to read cmd().
        cmd = self.nlproto._decode(nl_msg).cmd()
        try:
            return self.rsp_by_value[cmd]
        except KeyError:
            return None

    def handle_ntf(self, decoded):
        """Decode a notification and append {'name', 'msg'[, 'raw']} to async_msg_queue."""
        msg = dict()
        if self.include_raw:
            msg['raw'] = decoded
        op = self.rsp_by_value[decoded.cmd()]
        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
        if op.fixed_header:
            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))

        msg['name'] = op['name']
        msg['msg'] = attrs
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Drain pending messages without blocking and queue any notifications."""
        while True:
            try:
                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                # Look the op up by protocol command; nl_msg.cmd() is the
                # family id for genetlink and would pick the wrong op.
                op = self._rsp_op_for(nl_msg)
                if op is None:
                    print("Unexpected msg id done while checking for ntf", nl_msg)
                    continue
                decoded = self.nlproto.decode(self, nl_msg, op)
                if decoded.cmd() not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", decoded)
                    continue

                self.handle_ntf(decoded)

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _encode_message(self, op, vals, flags, req_seq):
        """Encode a complete, length-prefixed request for ``op``.

        vals: fixed-header members and attributes (None is treated as {}).
        flags: extra NLM_F_* flags OR-ed into REQUEST|ACK.
        """
        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag

        # _encode_struct() pops the fixed-header members so that only the
        # attributes remain; work on a copy to leave the caller's dict intact.
        vals = dict(vals) if vals else {}

        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            msg += self._encode_struct(op.fixed_header, vals)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
        msg = _genl_msg_finalize(msg)
        return msg

    def _ops(self, ops):
        """Send a batch of requests and collect one result per request.

        ops: iterable of (method, vals, flags) tuples.
        Returns a list in request order: a list of messages for dumps, and
        None / the single message / a list of messages otherwise.
        Raises NlError when the kernel reports an error.
        """
        reqs_by_seq = {}
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            reqs_by_seq[req_seq] = (op, msg, flags)
            payload += msg
            req_seq += 1

        self.sock.send(payload, 0)

        done = False
        rsp = []
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    if nl_msg.extack:
                        self._decode_extack(req_msg, op, nl_msg.extack)
                else:
                    # Not one of our requests (e.g. a notification).
                    if nl_msg.error:
                        raise NlError(nl_msg)
                    ntf_op = self._rsp_op_for(nl_msg)
                    if ntf_op is None:
                        print('Unexpected message: ' + repr(nl_msg))
                        continue
                    op = ntf_op
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    break

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp

    def _op(self, method, vals, flags=None, dump=False):
        """Run a single operation; dump=True adds NLM_F_DUMP."""
        # Copy so that appending NLM_F_DUMP never mutates the caller's list.
        req_flags = list(flags) if flags else []
        if dump:
            req_flags.append(Netlink.NLM_F_DUMP)

        ops = [(method, vals, req_flags)]
        return self._ops(ops)[0]

    def do(self, method, vals, flags=None):
        """Run ``method`` as a 'do' request and return its decoded reply."""
        return self._op(method, vals, flags)

    def dump(self, method, vals):
        """Run ``method`` as a dump and return the list of decoded messages."""
        return self._op(method, vals, dump=True)

    def do_multi(self, ops):
        """Send several (method, vals, flags) requests in one batch; see _ops()."""
        return self._ops(ops)