# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
import functools
import os
import random
import socket
import struct
from struct import Struct
import sys
import yaml
import ipaddress
import uuid

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Netlink and Generic Netlink protocol constants used by this module."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    # Note: these bit values overlap with NLM_F_ROOT / NLM_F_MATCH above.
    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    # Note: these bit values overlap as well; NLM_F_ACK_TLVS is checked on
    # NLMSG_ERROR / NLMSG_DONE replies in NlMsg.
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    # Flag bits masked off nla_type to obtain the attribute type number
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6


class NlError(Exception):
    """Raised when the kernel answers a request with a non-zero error code.

    The offending NlMsg is kept in ``nl_msg``; its ``error`` field holds the
    (negative) errno value reported by the kernel.
    """
    def __init__(self, nl_msg):
        self.nl_msg = nl_msg

    def __str__(self):
        return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}"


class ConfigError(Exception):
    """Raised for invalid configuration of this library."""
    pass


class NlAttr:
    """A single netlink attribute (TLV) parsed out of a raw buffer.

    Attributes:
        type: attribute type number with NLA_TYPE_MASK flag bits removed.
        is_nest: non-zero when NLA_F_NESTED was set on the attribute.
        payload_len: the raw nla_len field (4-byte header + payload,
            excluding alignment padding).
        full_len: nla_len rounded up to 4 bytes, i.e. the distance to the
            next attribute in the buffer.
        raw: the payload bytes with the header stripped.
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute starting at ``offset`` in ``raw``.

        Raises:
            struct.error: fewer than 4 bytes are left for the header.
            Exception: nla_len is smaller than the 4-byte header.
        """
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # A length below the header size is malformed; accepting it would give
        # full_len == 0 and make every caller that advances by full_len spin
        # forever on the same offset.
        if self._len < 4:
            raise Exception(f"Malformed netlink attribute at offset {offset}: length {self._len}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the precompiled Struct for scalar type ``attr_type``.

        ``byte_order`` of "big-endian" selects the big-endian format, any
        other truthy value the little-endian one, and None the native one.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" \
                else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Decode the payload as a single scalar of ``attr_type``."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable-width (4 or 8 byte) scalar.

        The first character of ``attr_type`` ('u' or 's') selects signedness;
        the width is taken from the payload length.
        """
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode a NUL-terminated ASCII string (the last byte is dropped)."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Decode the payload as a packed array of native scalars of ``type``."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """All attributes laid out back-to-back in ``msg`` starting at ``offset``."""
    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """One netlink message parsed out of a receive (or request) buffer.

    For NLMSG_ERROR and NLMSG_DONE messages ``error`` holds the status code
    from the payload and ``done`` is set to 1. When the kernel flags the
    message with NLM_F_ACK_TLVS, the extended-ack TLVs are decoded into the
    ``extack`` dict; ``attr_space`` (optional) is used to turn a missing
    attribute type number into its spec name.
    """
    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # Extack TLVs follow the 4-byte error code and the echoed
            # 16-byte request header.
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        desc = spec['name']
                        if 'doc' in spec:
                            desc += f" ({spec['doc']})"
                        self.extack['miss-type'] = desc

    def cmd(self):
        """Return the message type (the command for netlink-raw families)."""
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages contained in one receive buffer."""
    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            # nl_len includes the 16-byte header; anything shorter is
            # malformed and would stop the offset from advancing.
            if msg.nl_len < 16:
                raise Exception(f"Malformed netlink message at offset {offset}: length {msg.nl_len}")
            offset += msg.nl_len
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Cache filled by _genl_load_families(): family name -> dict with the keys
# 'id', 'name' and optionally 'maxattr' / 'mcast' (group name -> group id).
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build a genetlink request header without the leading length field."""
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the 4-byte total length (which counts itself) to ``msg``."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump all generic netlink families via nlctrl into genl_family_name_to_id.

    On a netlink error the error code is printed and the cache is left with
    whatever families were parsed so far.
    """
    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        global genl_family_name_to_id
        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in NlAttrs(gm.raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """Generic netlink view of an NlMsg: splits off the 4-byte genlmsghdr."""
    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the genetlink command number."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl) + '\n'
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is only attached by NetlinkProtocol.decode(); messages
        # constructed directly (e.g. in _genl_load_families) lack it.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Message framing for a classic (netlink-raw) netlink protocol."""
    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack the nlmsghdr fields after nlmsg_len (added at finalize time)."""
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        """Build a request header; the command is the netlink message type."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg):
        """Wrap ``nl_msg`` and attach ``raw_attrs`` (attributes after any fixed header)."""
        msg = self._decode(nl_msg)
        fixed_header_size = 0
        if ynl:
            op = ynl.rsp_by_value[msg.cmd()]
            fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the multicast group id declared for ``mcast_name`` in the spec."""
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value

    def msghdr_size(self):
        """Size in bytes of the headers preceding the fixed header / attributes."""
        return 16


class GenlProtocol(NetlinkProtocol):
    """Message framing for a generic netlink family resolved via nlctrl.

    Raises KeyError when ``family_name`` is not registered with the kernel.
    """
    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']

    def message(self, flags, command, version, seq=None):
        """Build nlmsghdr (typed with the family id) followed by genlmsghdr."""
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the multicast group id the kernel reported for ``mcast_name``."""
        if mcast_name not in self.genl_family['mcast']:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return self.genl_family['mcast'][mcast_name]

    def msghdr_size(self):
        return super().msghdr_size() + 4


class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes, innermost first.

    Used to look up the value of a sub-message selector attribute, which may
    live in the current attribute set or in any enclosing one.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        outer_scopes = outer.scopes if outer else []
        inner_scope = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner_scope] + outer_scopes

    def lookup(self, name):
        """Return the value of attribute ``name`` from the innermost scope defining it.

        Raises:
            Exception: the defining scope has no value for ``name``, or no
                scope defines it at all.
        """
        for scope in self.scopes:
            if name in scope.spec:
                if name in scope.values:
                    return scope.values[name]
                spec_name = scope.spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")


#
# YNL implementation details.
410# 411 412 413class YnlFamily(SpecFamily): 414 def __init__(self, def_path, schema=None, process_unknown=False, 415 recv_size=0): 416 super().__init__(def_path, schema) 417 418 self.include_raw = False 419 self.process_unknown = process_unknown 420 421 try: 422 if self.proto == "netlink-raw": 423 self.nlproto = NetlinkProtocol(self.yaml['name'], 424 self.yaml['protonum']) 425 else: 426 self.nlproto = GenlProtocol(self.yaml['name']) 427 except KeyError: 428 raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel") 429 430 self._recv_dbg = False 431 # Note that netlink will use conservative (min) message size for 432 # the first dump recv() on the socket, our setting will only matter 433 # from the second recv() on. 434 self._recv_size = recv_size if recv_size else 131072 435 # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo) 436 # for a message, so smaller receive sizes will lead to truncation. 437 # Note that the min size for other families may be larger than 4k! 
438 if self._recv_size < 4000: 439 raise ConfigError() 440 441 self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num) 442 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) 443 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1) 444 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1) 445 446 self.async_msg_ids = set() 447 self.async_msg_queue = [] 448 449 for msg in self.msgs.values(): 450 if msg.is_async: 451 self.async_msg_ids.add(msg.rsp_value) 452 453 for op_name, op in self.ops.items(): 454 bound_f = functools.partial(self._op, op_name) 455 setattr(self, op.ident_name, bound_f) 456 457 458 def ntf_subscribe(self, mcast_name): 459 mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups) 460 self.sock.bind((0, 0)) 461 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP, 462 mcast_id) 463 464 def set_recv_dbg(self, enabled): 465 self._recv_dbg = enabled 466 467 def _recv_dbg_print(self, reply, nl_msgs): 468 if not self._recv_dbg: 469 return 470 print("Recv: read", len(reply), "bytes,", 471 len(nl_msgs.msgs), "messages", file=sys.stderr) 472 for nl_msg in nl_msgs: 473 print(" ", nl_msg, file=sys.stderr) 474 475 def _encode_enum(self, attr_spec, value): 476 enum = self.consts[attr_spec['enum']] 477 if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): 478 scalar = 0 479 if isinstance(value, str): 480 value = [value] 481 for single_value in value: 482 scalar += enum.entries[single_value].user_value(as_flags = True) 483 return scalar 484 else: 485 return enum.entries[value].user_value() 486 487 def _get_scalar(self, attr_spec, value): 488 try: 489 return int(value) 490 except (ValueError, TypeError) as e: 491 if 'enum' not in attr_spec: 492 raise e 493 return self._encode_enum(attr_spec, value) 494 495 def _add_attr(self, space, name, value, search_attrs): 496 try: 497 attr = self.attr_sets[space][name] 498 except KeyError: 499 
raise Exception(f"Space '{space}' has no attribute '{name}'") 500 nl_type = attr.value 501 502 if attr.is_multi and isinstance(value, list): 503 attr_payload = b'' 504 for subvalue in value: 505 attr_payload += self._add_attr(space, name, subvalue, search_attrs) 506 return attr_payload 507 508 if attr["type"] == 'nest': 509 nl_type |= Netlink.NLA_F_NESTED 510 attr_payload = b'' 511 sub_attrs = SpaceAttrs(self.attr_sets[space], value, search_attrs) 512 for subname, subvalue in value.items(): 513 attr_payload += self._add_attr(attr['nested-attributes'], 514 subname, subvalue, sub_attrs) 515 elif attr["type"] == 'flag': 516 if not value: 517 # If value is absent or false then skip attribute creation. 518 return b'' 519 attr_payload = b'' 520 elif attr["type"] == 'string': 521 attr_payload = str(value).encode('ascii') + b'\x00' 522 elif attr["type"] == 'binary': 523 if isinstance(value, bytes): 524 attr_payload = value 525 elif isinstance(value, str): 526 attr_payload = bytes.fromhex(value) 527 elif isinstance(value, dict) and attr.struct_name: 528 attr_payload = self._encode_struct(attr.struct_name, value) 529 else: 530 raise Exception(f'Unknown type for binary attribute, value: {value}') 531 elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar: 532 scalar = self._get_scalar(attr, value) 533 if attr.is_auto_scalar: 534 attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64') 535 else: 536 attr_type = attr["type"] 537 format = NlAttr.get_format(attr_type, attr.byte_order) 538 attr_payload = format.pack(scalar) 539 elif attr['type'] in "bitfield32": 540 scalar_value = self._get_scalar(attr, value["value"]) 541 scalar_selector = self._get_scalar(attr, value["selector"]) 542 attr_payload = struct.pack("II", scalar_value, scalar_selector) 543 elif attr['type'] == 'sub-message': 544 msg_format = self._resolve_selector(attr, search_attrs) 545 attr_payload = b'' 546 if msg_format.fixed_header: 547 attr_payload += 
self._encode_struct(msg_format.fixed_header, value) 548 if msg_format.attr_set: 549 if msg_format.attr_set in self.attr_sets: 550 nl_type |= Netlink.NLA_F_NESTED 551 sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs) 552 for subname, subvalue in value.items(): 553 attr_payload += self._add_attr(msg_format.attr_set, 554 subname, subvalue, sub_attrs) 555 else: 556 raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'") 557 else: 558 raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}') 559 560 pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4) 561 return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad 562 563 def _decode_enum(self, raw, attr_spec): 564 enum = self.consts[attr_spec['enum']] 565 if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): 566 i = 0 567 value = set() 568 while raw: 569 if raw & 1: 570 value.add(enum.entries_by_val[i].name) 571 raw >>= 1 572 i += 1 573 else: 574 value = enum.entries_by_val[raw].name 575 return value 576 577 def _decode_binary(self, attr, attr_spec): 578 if attr_spec.struct_name: 579 decoded = self._decode_struct(attr.raw, attr_spec.struct_name) 580 elif attr_spec.sub_type: 581 decoded = attr.as_c_array(attr_spec.sub_type) 582 else: 583 decoded = attr.as_bin() 584 if attr_spec.display_hint: 585 decoded = self._formatted_string(decoded, attr_spec.display_hint) 586 return decoded 587 588 def _decode_array_nest(self, attr, attr_spec): 589 decoded = [] 590 offset = 0 591 while offset < len(attr.raw): 592 item = NlAttr(attr.raw, offset) 593 offset += item.full_len 594 595 subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes']) 596 decoded.append({ item.type: subattrs }) 597 return decoded 598 599 def _decode_nest_type_value(self, attr, attr_spec): 600 decoded = {} 601 value = attr 602 for name in attr_spec['type-value']: 603 value = NlAttr(value.raw, 0) 604 decoded[name] = value.type 605 subattrs = self._decode(NlAttrs(value.raw), 
attr_spec['nested-attributes']) 606 decoded.update(subattrs) 607 return decoded 608 609 def _decode_unknown(self, attr): 610 if attr.is_nest: 611 return self._decode(NlAttrs(attr.raw), None) 612 else: 613 return attr.as_bin() 614 615 def _rsp_add(self, rsp, name, is_multi, decoded): 616 if is_multi == None: 617 if name in rsp and type(rsp[name]) is not list: 618 rsp[name] = [rsp[name]] 619 is_multi = True 620 else: 621 is_multi = False 622 623 if not is_multi: 624 rsp[name] = decoded 625 elif name in rsp: 626 rsp[name].append(decoded) 627 else: 628 rsp[name] = [decoded] 629 630 def _resolve_selector(self, attr_spec, search_attrs): 631 sub_msg = attr_spec.sub_message 632 if sub_msg not in self.sub_msgs: 633 raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}") 634 sub_msg_spec = self.sub_msgs[sub_msg] 635 636 selector = attr_spec.selector 637 value = search_attrs.lookup(selector) 638 if value not in sub_msg_spec.formats: 639 raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'") 640 641 spec = sub_msg_spec.formats[value] 642 return spec 643 644 def _decode_sub_msg(self, attr, attr_spec, search_attrs): 645 msg_format = self._resolve_selector(attr_spec, search_attrs) 646 decoded = {} 647 offset = 0 648 if msg_format.fixed_header: 649 decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header)); 650 offset = self._struct_size(msg_format.fixed_header) 651 if msg_format.attr_set: 652 if msg_format.attr_set in self.attr_sets: 653 subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set) 654 decoded.update(subdict) 655 else: 656 raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'") 657 return decoded 658 659 def _decode(self, attrs, space, outer_attrs = None): 660 rsp = dict() 661 if space: 662 attr_space = self.attr_sets[space] 663 search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs) 664 665 for attr in attrs: 666 try: 667 attr_spec = 
attr_space.attrs_by_val[attr.type] 668 except (KeyError, UnboundLocalError): 669 if not self.process_unknown: 670 raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'") 671 attr_name = f"UnknownAttr({attr.type})" 672 self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr)) 673 continue 674 675 if attr_spec["type"] == 'nest': 676 subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs) 677 decoded = subdict 678 elif attr_spec["type"] == 'string': 679 decoded = attr.as_strz() 680 elif attr_spec["type"] == 'binary': 681 decoded = self._decode_binary(attr, attr_spec) 682 elif attr_spec["type"] == 'flag': 683 decoded = True 684 elif attr_spec.is_auto_scalar: 685 decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order) 686 elif attr_spec["type"] in NlAttr.type_formats: 687 decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order) 688 if 'enum' in attr_spec: 689 decoded = self._decode_enum(decoded, attr_spec) 690 elif attr_spec["type"] == 'array-nest': 691 decoded = self._decode_array_nest(attr, attr_spec) 692 elif attr_spec["type"] == 'bitfield32': 693 value, selector = struct.unpack("II", attr.raw) 694 if 'enum' in attr_spec: 695 value = self._decode_enum(value, attr_spec) 696 selector = self._decode_enum(selector, attr_spec) 697 decoded = {"value": value, "selector": selector} 698 elif attr_spec["type"] == 'sub-message': 699 decoded = self._decode_sub_msg(attr, attr_spec, search_attrs) 700 elif attr_spec["type"] == 'nest-type-value': 701 decoded = self._decode_nest_type_value(attr, attr_spec) 702 else: 703 if not self.process_unknown: 704 raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}') 705 decoded = self._decode_unknown(attr) 706 707 self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded) 708 709 return rsp 710 711 def _decode_extack_path(self, attrs, attr_set, offset, target): 712 for attr in attrs: 713 try: 714 attr_spec = 
attr_set.attrs_by_val[attr.type] 715 except KeyError: 716 raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'") 717 if offset > target: 718 break 719 if offset == target: 720 return '.' + attr_spec.name 721 722 if offset + attr.full_len <= target: 723 offset += attr.full_len 724 continue 725 if attr_spec['type'] != 'nest': 726 raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack") 727 offset += 4 728 subpath = self._decode_extack_path(NlAttrs(attr.raw), 729 self.attr_sets[attr_spec['nested-attributes']], 730 offset, target) 731 if subpath is None: 732 return None 733 return '.' + attr_spec.name + subpath 734 735 return None 736 737 def _decode_extack(self, request, op, extack): 738 if 'bad-attr-offs' not in extack: 739 return 740 741 msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set)) 742 offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header) 743 path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset, 744 extack['bad-attr-offs']) 745 if path: 746 del extack['bad-attr-offs'] 747 extack['bad-attr'] = path 748 749 def _struct_size(self, name): 750 if name: 751 members = self.consts[name].members 752 size = 0 753 for m in members: 754 if m.type in ['pad', 'binary']: 755 if m.struct: 756 size += self._struct_size(m.struct) 757 else: 758 size += m.len 759 else: 760 format = NlAttr.get_format(m.type, m.byte_order) 761 size += format.size 762 return size 763 else: 764 return 0 765 766 def _decode_struct(self, data, name): 767 members = self.consts[name].members 768 attrs = dict() 769 offset = 0 770 for m in members: 771 value = None 772 if m.type == 'pad': 773 offset += m.len 774 elif m.type == 'binary': 775 if m.struct: 776 len = self._struct_size(m.struct) 777 value = self._decode_struct(data[offset : offset + len], 778 m.struct) 779 offset += len 780 else: 781 value = data[offset : offset + m.len] 782 offset += m.len 783 else: 784 format = NlAttr.get_format(m.type, 
m.byte_order) 785 [ value ] = format.unpack_from(data, offset) 786 offset += format.size 787 if value is not None: 788 if m.enum: 789 value = self._decode_enum(value, m) 790 elif m.display_hint: 791 value = self._formatted_string(value, m.display_hint) 792 attrs[m.name] = value 793 return attrs 794 795 def _encode_struct(self, name, vals): 796 members = self.consts[name].members 797 attr_payload = b'' 798 for m in members: 799 value = vals.pop(m.name) if m.name in vals else None 800 if m.type == 'pad': 801 attr_payload += bytearray(m.len) 802 elif m.type == 'binary': 803 if m.struct: 804 if value is None: 805 value = dict() 806 attr_payload += self._encode_struct(m.struct, value) 807 else: 808 if value is None: 809 attr_payload += bytearray(m.len) 810 else: 811 attr_payload += bytes.fromhex(value) 812 else: 813 if value is None: 814 value = 0 815 format = NlAttr.get_format(m.type, m.byte_order) 816 attr_payload += format.pack(value) 817 return attr_payload 818 819 def _formatted_string(self, raw, display_hint): 820 if display_hint == 'mac': 821 formatted = ':'.join('%02x' % b for b in raw) 822 elif display_hint == 'hex': 823 formatted = bytes.hex(raw, ' ') 824 elif display_hint in [ 'ipv4', 'ipv6' ]: 825 formatted = format(ipaddress.ip_address(raw)) 826 elif display_hint == 'uuid': 827 formatted = str(uuid.UUID(bytes=raw)) 828 else: 829 formatted = raw 830 return formatted 831 832 def handle_ntf(self, decoded): 833 msg = dict() 834 if self.include_raw: 835 msg['raw'] = decoded 836 op = self.rsp_by_value[decoded.cmd()] 837 attrs = self._decode(decoded.raw_attrs, op.attr_set.name) 838 if op.fixed_header: 839 attrs.update(self._decode_struct(decoded.raw, op.fixed_header)) 840 841 msg['name'] = op['name'] 842 msg['msg'] = attrs 843 self.async_msg_queue.append(msg) 844 845 def check_ntf(self): 846 while True: 847 try: 848 reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT) 849 except BlockingIOError: 850 return 851 852 nms = NlMsgs(reply) 853 
self._recv_dbg_print(reply, nms) 854 for nl_msg in nms: 855 if nl_msg.error: 856 print("Netlink error in ntf!?", os.strerror(-nl_msg.error)) 857 print(nl_msg) 858 continue 859 if nl_msg.done: 860 print("Netlink done while checking for ntf!?") 861 continue 862 863 decoded = self.nlproto.decode(self, nl_msg) 864 if decoded.cmd() not in self.async_msg_ids: 865 print("Unexpected msg id done while checking for ntf", decoded) 866 continue 867 868 self.handle_ntf(decoded) 869 870 def operation_do_attributes(self, name): 871 """ 872 For a given operation name, find and return a supported 873 set of attributes (as a dict). 874 """ 875 op = self.find_operation(name) 876 if not op: 877 return None 878 879 return op['do']['request']['attributes'].copy() 880 881 def _op(self, method, vals, flags=None, dump=False): 882 op = self.ops[method] 883 884 nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK 885 for flag in flags or []: 886 nl_flags |= flag 887 if dump: 888 nl_flags |= Netlink.NLM_F_DUMP 889 890 req_seq = random.randint(1024, 65535) 891 msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq) 892 if op.fixed_header: 893 msg += self._encode_struct(op.fixed_header, vals) 894 search_attrs = SpaceAttrs(op.attr_set, vals) 895 for name, value in vals.items(): 896 msg += self._add_attr(op.attr_set.name, name, value, search_attrs) 897 msg = _genl_msg_finalize(msg) 898 899 self.sock.send(msg, 0) 900 901 done = False 902 rsp = [] 903 while not done: 904 reply = self.sock.recv(self._recv_size) 905 nms = NlMsgs(reply, attr_space=op.attr_set) 906 self._recv_dbg_print(reply, nms) 907 for nl_msg in nms: 908 if nl_msg.extack: 909 self._decode_extack(msg, op, nl_msg.extack) 910 911 if nl_msg.error: 912 raise NlError(nl_msg) 913 if nl_msg.done: 914 if nl_msg.extack: 915 print("Netlink warning:") 916 print(nl_msg) 917 done = True 918 break 919 920 decoded = self.nlproto.decode(self, nl_msg) 921 922 # Check if this is a reply to our request 923 if nl_msg.nl_seq != req_seq or 
decoded.cmd() != op.rsp_value: 924 if decoded.cmd() in self.async_msg_ids: 925 self.handle_ntf(decoded) 926 continue 927 else: 928 print('Unexpected message: ' + repr(decoded)) 929 continue 930 931 rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name) 932 if op.fixed_header: 933 rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header)) 934 rsp.append(rsp_msg) 935 936 if not rsp: 937 return None 938 if not dump and len(rsp) == 1: 939 return rsp[0] 940 return rsp 941 942 def do(self, method, vals, flags=None): 943 return self._op(method, vals, flags) 944 945 def dump(self, method, vals): 946 return self._op(method, vals, [], dump=True) 947