# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
import functools
import os
import random
import socket
import struct
from struct import Struct
import sys
import yaml
import ipaddress
import uuid

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Netlink and generic-netlink protocol constants used by this module."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    # Flag bits carried in nla_type; masked off to get the attribute type.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6


class NlError(Exception):
    """Exception wrapping a netlink message that carries a non-zero error code."""

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg

    def __str__(self):
        # The error field holds a negated errno value.
        return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}"


class ConfigError(Exception):
    """Raised for invalid local configuration (e.g. a too-small receive size)."""
    pass


class NlAttr:
    """A single netlink attribute (4-byte nlattr header + payload) read from a buffer."""

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute starting at ``offset`` in ``raw``.

        Raises:
            Exception: the header's length field is smaller than the header
                itself. Such an attribute would have full_len == 0 and make
                every caller that walks attributes loop forever.
        """
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        if self._len < 4:
            raise Exception(f"Malformed netlink attribute at offset {offset}: length {self._len}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        # Attributes are padded to a 4-byte boundary.
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the precompiled Struct for ``attr_type`` in the given byte order.

        ``byte_order`` of "big-endian" selects big, any other truthy value
        selects little, and a falsy value selects native order.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as a single fixed-size integer of ``attr_type``."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable-width (32 or 64 bit) integer; signedness comes from attr_type[0]."""
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as an ASCII string, dropping one trailing NUL if present.

        Only a terminating NUL is removed; a payload without one keeps its
        last character.
        """
        raw = self.raw
        if raw.endswith(b'\x00'):
            raw = raw[:-1]
        return raw.decode('ascii')

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of native-order ``type`` scalars."""
        fmt = self.get_format(type)
        return [x[0] for x in fmt.iter_unpack(self.raw)]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """Sequence of NlAttr parsed back-to-back from ``msg`` starting at ``offset``."""

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """One netlink message (16-byte nlmsghdr + payload) parsed from a buffer.

    For NLMSG_ERROR the error code is stored in ``error`` and ``done`` is set.
    For NLMSG_DONE only ``done`` is set. When the NLM_F_ACK_TLVS flag is set,
    the extended-ACK TLVs are decoded into the ``extack`` dict.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # 4-byte error code + the echoed request nlmsghdr (16 bytes).
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = {}
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        desc = spec['name']
                        if 'doc' in spec:
                            desc += f" ({spec['doc']})"
                        self.extack['miss-type'] = desc

    def cmd(self):
        """Return the nlmsghdr type, which plays the role of the command for raw netlink."""
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages packed back-to-back in one recv() buffer."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            # A length shorter than the header would never advance offset.
            if msg.nl_len < 16:
                raise Exception(f"Malformed netlink message at offset {offset}: length {msg.nl_len}")
            offset += msg.nl_len
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Cache filled by _genl_load_families(): family name -> dict with 'id',
# 'name' and, when reported, 'maxattr' and 'mcast' (group name -> id).
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build a genetlink request without the leading nlmsg_len field."""
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the 4-byte total length (including itself) to a message body."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump all generic netlink families via nlctrl into genl_family_name_to_id.

    On a netlink error the error is printed and the function returns, leaving
    whatever was collected so far in the cache.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        genl_family_name_to_id = {}

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = {}
                for attr in NlAttrs(gm.raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = {}
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """Generic netlink view of an NlMsg: 4-byte genlmsghdr followed by ``raw``."""

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the genetlink command byte."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is only attached by NetlinkProtocol.decode(); tolerate its
        # absence so any GenlMsg can be printed.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Message framing for a raw (non-genetlink) netlink protocol family."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Build an nlmsghdr without nlmsg_len (added by _genl_msg_finalize)."""
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        """Build a request header; for raw netlink the command is the nlmsg type."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op=None):
        """Wrap ``nl_msg`` for this protocol and attach ``raw_attrs``.

        When ``ynl`` is given, the attributes start after the op's fixed
        header. ``op`` defaults to the op found via
        ``ynl.rsp_by_value[msg.cmd()]``; pass it explicitly for messages
        (e.g. requests) whose command is not a reply id.
        """
        msg = self._decode(nl_msg)
        fixed_header_size = 0
        if ynl:
            if op is None:
                op = ynl.rsp_by_value[msg.cmd()]
            fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Look up a multicast group id in the spec's group table."""
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value

    def msghdr_size(self):
        """Size of the protocol header preceding the fixed header/attributes."""
        return 16


class GenlProtocol(NetlinkProtocol):
    """Message framing for a generic netlink family resolved via nlctrl."""

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        # KeyError here means the kernel does not know the family.
        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']

    def message(self, flags, command, version, seq=None):
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Look up a multicast group id in the groups reported by the kernel."""
        if mcast_name not in self.genl_family['mcast']:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return self.genl_family['mcast'][mcast_name]

    def msghdr_size(self):
        # nlmsghdr + genlmsghdr
        return super().msghdr_size() + 4


class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes used to resolve selectors.

    The innermost scope is searched first, then each enclosing scope.
    """

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer=None):
        outer_scopes = outer.scopes if outer else []
        inner_scope = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner_scope] + outer_scopes

    def lookup(self, name):
        """Return the value of attribute ``name`` from the nearest scope defining it.

        Raises:
            Exception: the defining scope has no value for it, or no scope
                defines it.
        """
        for scope in self.scopes:
            if name in scope.spec:
                if name in scope.values:
                    return scope.values[name]
                spec_name = scope.spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """Netlink client driven by a YNL family spec.

    On construction it opens a netlink socket for the family and binds, for
    every spec operation, a method named ``op.ident_name`` that forwards to
    ``_op`` with that operation's name.
    """

    def __init__(self, def_path, schema=None, process_unknown=False,
                 recv_size=0):
        """
        Args:
            def_path: spec path, passed to SpecFamily.
            schema: optional schema, passed to SpecFamily.
            process_unknown: when True, attributes or types missing from the
                spec are decoded generically instead of raising.
            recv_size: recv() buffer size; 0 selects 131072.

        Raises:
            Exception: the family is not supported by the kernel.
            ConfigError: the receive size is below 4000 bytes.
        """
        super().__init__(def_path, schema)

        self.include_raw = False
        self.process_unknown = process_unknown

        try:
            if self.proto == "netlink-raw":
                self.nlproto = NetlinkProtocol(self.yaml['name'],
                                               self.yaml['protonum'])
            else:
                self.nlproto = GenlProtocol(self.yaml['name'])
        except KeyError:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

        self._recv_dbg = False
        # Note that netlink will use conservative (min) message size for
        # the first dump recv() on the socket, our setting will only matter
        # from the second recv() on.
        self._recv_size = recv_size if recv_size else 131072
        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
        # for a message, so smaller receive sizes will lead to truncation.
        # Note that the min size for other families may be larger than 4k!
        if self._recv_size < 4000:
            raise ConfigError(f"recv_size {self._recv_size} is too small, must be at least 4000")

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

    def ntf_subscribe(self, mcast_name):
        """Join the multicast group ``mcast_name`` on this family's socket."""
        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_id)

    def set_recv_dbg(self, enabled):
        """Enable/disable printing of every received buffer to stderr."""
        self._recv_dbg = enabled

    def _recv_dbg_print(self, reply, nl_msgs):
        if not self._recv_dbg:
            return
        print("Recv: read", len(reply), "bytes,",
              len(nl_msgs.msgs), "messages", file=sys.stderr)
        for nl_msg in nl_msgs:
            print("    ", nl_msg, file=sys.stderr)

    def _encode_enum(self, attr_spec, value):
        """Map enum entry name(s) to the integer sent on the wire.

        For flag-style enums ``value`` may be one name or an iterable of
        names; their flag values are summed.
        """
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            scalar = 0
            if isinstance(value, str):
                value = [value]
            for single_value in value:
                scalar += enum.entries[single_value].user_value(as_flags=True)
            return scalar
        return enum.entries[value].user_value()

    def _get_scalar(self, attr_spec, value):
        """Convert ``value`` to int, falling back to enum-name encoding if the spec has an enum."""
        try:
            return int(value)
        except (ValueError, TypeError):
            if 'enum' not in attr_spec:
                raise
            return self._encode_enum(attr_spec, value)

    def _add_attr(self, space, name, value, search_attrs):
        """Encode attribute ``name`` of attribute-set ``space`` (header + payload + padding).

        ``search_attrs`` is the SpaceAttrs scope chain used to resolve
        sub-message selectors. Returns b'' for a false/absent flag.
        """
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value

        if attr.is_multi and isinstance(value, list):
            attr_payload = b''
            for subvalue in value:
                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
            return attr_payload

        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            # The values inside the nest belong to the nested attribute-set,
            # so selectors must be resolved against that set first.
            sub_space = attr['nested-attributes']
            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
        elif attr["type"] == 'flag':
            if not value:
                # If value is absent or false then skip attribute creation.
                return b''
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            if isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                attr_payload = bytes.fromhex(value)
            elif isinstance(value, dict) and attr.struct_name:
                attr_payload = self._encode_struct(attr.struct_name, value)
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
            scalar = self._get_scalar(attr, value)
            if attr.is_auto_scalar:
                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
            else:
                attr_type = attr["type"]
            fmt = NlAttr.get_format(attr_type, attr.byte_order)
            attr_payload = fmt.pack(scalar)
        elif attr['type'] == "bitfield32":
            scalar_value = self._get_scalar(attr, value["value"])
            scalar_selector = self._get_scalar(attr, value["selector"])
            attr_payload = struct.pack("II", scalar_value, scalar_selector)
        elif attr['type'] == 'sub-message':
            msg_format = self._resolve_selector(attr, search_attrs)
            attr_payload = b''
            if msg_format.fixed_header:
                attr_payload += self._encode_struct(msg_format.fixed_header, value)
            if msg_format.attr_set:
                if msg_format.attr_set in self.attr_sets:
                    nl_type |= Netlink.NLA_F_NESTED
                    # SpaceAttrs needs the attribute-set object (it uses
                    # `in` and .yaml on it), not the set's name.
                    sub_attrs = SpaceAttrs(self.attr_sets[msg_format.attr_set],
                                           value, search_attrs)
                    for subname, subvalue in value.items():
                        attr_payload += self._add_attr(msg_format.attr_set,
                                                       subname, subvalue, sub_attrs)
                else:
                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Map an integer to an enum entry name, or a set of names for flag enums."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        """Decode a binary attribute as a struct, a C array, or raw bytes, then apply any display hint."""
        if attr_spec.struct_name:
            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = self._formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode_array_nest(self, attr, attr_spec):
        """Decode an array-nest into a list of {index_type: decoded_nest} dicts."""
        decoded = []
        offset = 0
        while offset < len(attr.raw):
            item = NlAttr(attr.raw, offset)
            offset += item.full_len

            subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
            decoded.append({item.type: subattrs})
        return decoded

    def _decode_nest_type_value(self, attr, attr_spec):
        """Decode nests whose attribute types carry values named by spec 'type-value'."""
        decoded = {}
        value = attr
        for name in attr_spec['type-value']:
            value = NlAttr(value.raw, 0)
            decoded[name] = value.type
        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
        decoded.update(subattrs)
        return decoded

    def _decode_unknown(self, attr):
        """Decode an attribute without a spec: recurse into nests, else return raw bytes."""
        if attr.is_nest:
            return self._decode(NlAttrs(attr.raw), None)
        return attr.as_bin()

    def _rsp_add(self, rsp, name, is_multi, decoded):
        """Store ``decoded`` under ``name`` in ``rsp``.

        ``is_multi`` True appends to a list, False overwrites. None means
        "unknown": the first value is stored as-is and repeated values turn
        the entry into a list that keeps accumulating.
        """
        if is_multi is None:
            if name in rsp and not isinstance(rsp[name], list):
                rsp[name] = [rsp[name]]
            # Any repeat (2nd, 3rd, ...) must append to the list built above
            # rather than overwrite it.
            is_multi = name in rsp

        if not is_multi:
            rsp[name] = decoded
        elif name in rsp:
            rsp[name].append(decoded)
        else:
            rsp[name] = [decoded]

    def _resolve_selector(self, attr_spec, search_attrs):
        """Pick the sub-message format selected by the value of attr_spec.selector."""
        sub_msg = attr_spec.sub_message
        if sub_msg not in self.sub_msgs:
            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
        sub_msg_spec = self.sub_msgs[sub_msg]

        selector = attr_spec.selector
        value = search_attrs.lookup(selector)
        if value not in sub_msg_spec.formats:
            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")

        spec = sub_msg_spec.formats[value]
        return spec

    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
        """Decode a sub-message attribute: optional fixed header, then its attribute-set."""
        msg_format = self._resolve_selector(attr_spec, search_attrs)
        decoded = {}
        offset = 0
        if msg_format.fixed_header:
            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
            offset = self._struct_size(msg_format.fixed_header)
        if msg_format.attr_set:
            if msg_format.attr_set in self.attr_sets:
                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
                decoded.update(subdict)
            else:
                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
        return decoded

    def _decode(self, attrs, space, outer_attrs=None):
        """Decode ``attrs`` according to attribute-set ``space`` into a dict.

        With a falsy ``space`` every attribute is treated as unknown. Unknown
        attributes raise unless ``process_unknown`` is set.
        """
        rsp = dict()
        attr_space = None
        search_attrs = None
        if space:
            attr_space = self.attr_sets[space]
            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)

        for attr in attrs:
            if attr_space is None or attr.type not in attr_space.attrs_by_val:
                if not self.process_unknown:
                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
                attr_name = f"UnknownAttr({attr.type})"
                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
                continue
            attr_spec = attr_space.attrs_by_val[attr.type]

            if attr_spec["type"] == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
                decoded = subdict
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                decoded = True
            elif attr_spec.is_auto_scalar:
                decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
            elif attr_spec["type"] in NlAttr.type_formats:
                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
                if 'enum' in attr_spec:
                    decoded = self._decode_enum(decoded, attr_spec)
            elif attr_spec["type"] == 'array-nest':
                decoded = self._decode_array_nest(attr, attr_spec)
            elif attr_spec["type"] == 'bitfield32':
                value, selector = struct.unpack("II", attr.raw)
                if 'enum' in attr_spec:
                    value = self._decode_enum(value, attr_spec)
                    selector = self._decode_enum(selector, attr_spec)
                decoded = {"value": value, "selector": selector}
            elif attr_spec["type"] == 'sub-message':
                decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
            elif attr_spec["type"] == 'nest-type-value':
                decoded = self._decode_nest_type_value(attr, attr_spec)
            else:
                if not self.process_unknown:
                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
                decoded = self._decode_unknown(attr)

            self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return the dotted attribute path located at byte ``target``, or None.

        ``offset`` is the position of the first attribute in ``attrs`` within
        the same buffer that ``target`` refers to.
        """
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Skip this attribute's 4-byte header to reach its first child.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, op, extack):
        """Replace extack 'bad-attr-offs' with a readable 'bad-attr' path when it can be resolved."""
        if 'bad-attr-offs' not in extack:
            return

        # Offsets are measured from the start of the request message. Parse the
        # request's attributes straight from its bytes: nlproto.decode() would
        # look the *request* command up in rsp_by_value, which may not exist
        # (KeyError) or may map to an op with a different fixed header.
        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
        path = self._decode_extack_path(NlAttrs(request, offset), op.attr_set,
                                        offset, extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def _struct_size(self, name):
        """Return the encoded size in bytes of struct ``name`` (0 for a falsy name)."""
        if not name:
            return 0
        size = 0
        for m in self.consts[name].members:
            if m.type in ['pad', 'binary']:
                if m.struct:
                    size += self._struct_size(m.struct)
                else:
                    size += m.len
            else:
                fmt = NlAttr.get_format(m.type, m.byte_order)
                size += fmt.size
        return size

    def _decode_struct(self, data, name):
        """Decode struct ``name`` from ``data`` into a dict keyed by member name (pads omitted)."""
        members = self.consts[name].members
        attrs = dict()
        offset = 0
        for m in members:
            value = None
            if m.type == 'pad':
                offset += m.len
            elif m.type == 'binary':
                if m.struct:
                    sub_len = self._struct_size(m.struct)
                    value = self._decode_struct(data[offset : offset + sub_len],
                                                m.struct)
                    offset += sub_len
                else:
                    value = data[offset : offset + m.len]
                    offset += m.len
            else:
                fmt = NlAttr.get_format(m.type, m.byte_order)
                [value] = fmt.unpack_from(data, offset)
                offset += fmt.size
            if value is not None:
                if m.enum:
                    value = self._decode_enum(value, m)
                elif m.display_hint:
                    value = self._formatted_string(value, m.display_hint)
                attrs[m.name] = value
        return attrs

    def _encode_struct(self, name, vals):
        """Encode struct ``name`` from ``vals``; missing members are zero-filled.

        Members are popped from ``vals`` as they are consumed, so callers can
        encode the remaining keys as attributes afterwards.
        """
        members = self.consts[name].members
        attr_payload = b''
        for m in members:
            value = vals.pop(m.name) if m.name in vals else None
            if m.type == 'pad':
                attr_payload += bytearray(m.len)
            elif m.type == 'binary':
                if m.struct:
                    if value is None:
                        value = dict()
                    attr_payload += self._encode_struct(m.struct, value)
                else:
                    if value is None:
                        attr_payload += bytearray(m.len)
                    else:
                        attr_payload += bytes.fromhex(value)
            else:
                if value is None:
                    value = 0
                fmt = NlAttr.get_format(m.type, m.byte_order)
                attr_payload += fmt.pack(value)
        return attr_payload

    def _formatted_string(self, raw, display_hint):
        """Render raw bytes per display hint ('mac', 'hex', 'ipv4'/'ipv6', 'uuid')."""
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            formatted = bytes.hex(raw, ' ')
        elif display_hint in ['ipv4', 'ipv6']:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def handle_ntf(self, decoded):
        """Decode a notification and append {'name', 'msg'[, 'raw']} to async_msg_queue."""
        msg = dict()
        if self.include_raw:
            msg['raw'] = decoded
        op = self.rsp_by_value[decoded.cmd()]
        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
        if op.fixed_header:
            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))

        msg['name'] = op['name']
        msg['msg'] = attrs
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Drain pending messages without blocking, queueing any notifications."""
        while True:
            try:
                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                decoded = self.nlproto.decode(self, nl_msg)
                if decoded.cmd() not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", decoded)
                    continue

                self.handle_ntf(decoded)

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _op(self, method, vals, flags=None, dump=False):
        """Send request ``method`` with ``vals`` and collect the replies.

        Returns None when there are no replies, a single dict for a non-dump
        request with exactly one reply, and a list of dicts otherwise.

        Raises:
            NlError: the kernel answered with an error.
        """
        op = self.ops[method]

        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag
        if dump:
            nl_flags |= Netlink.NLM_F_DUMP

        req_seq = random.randint(1024, 65535)
        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            msg += self._encode_struct(op.fixed_header, vals)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
        msg = _genl_msg_finalize(msg)

        self.sock.send(msg, 0)

        done = False
        rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.extack:
                    self._decode_extack(msg, op, nl_msg.extack)

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)
                    done = True
                    break

                decoded = self.nlproto.decode(self, nl_msg)

                # Check if this is a reply to our request
                if nl_msg.nl_seq != req_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                rsp.append(rsp_msg)

        if not rsp:
            return None
        if not dump and len(rsp) == 1:
            return rsp[0]
        return rsp

    def do(self, method, vals, flags=None):
        """Perform a 'do' request for operation ``method``."""
        return self._op(method, vals, flags)

    def dump(self, method, vals):
        """Perform a dump request for operation ``method``; always returns a list or None."""
        return self._op(method, vals, [], dump=True)