# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
import functools
import os
import random
import socket
import struct
from struct import Struct
import yaml
import ipaddress
import uuid

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Numeric constants for the netlink and generic netlink wire protocol.

    Names follow the Linux UAPI headers so they can be cross-referenced.
    """

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4

    # Modifiers for GET requests. The same bit values are reused with a
    # different meaning for NEW requests and for ACK messages below.
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    # Modifiers for NEW requests
    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    # Flags for ACK messages
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    # Flag bits sharing the 16-bit nla_type field with the attribute type;
    # NlAttr masks them off to obtain the plain type.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6


class NlError(Exception):
    """Raised when the kernel answers a request with a non-zero error code."""

    def __init__(self, nl_msg):
        # Pass the message to Exception so .args is populated and the
        # exception can be re-created (e.g. when pickled).
        super().__init__(nl_msg)
        # The NlMsg carrying the negative errno and any extack data.
        self.nl_msg = nl_msg

    def __str__(self):
        return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}"


class NlAttr:
    """A single netlink attribute (TLV) parsed out of a byte buffer.

    Instance attributes:
        type: attribute type with the NESTED / NET_BYTEORDER bits masked off.
        is_nest: non-zero when NLA_F_NESTED was set.
        payload_len: the on-wire length field (it includes the 4-byte header).
        full_len: payload_len rounded up to 4-byte alignment, i.e. the
            distance to the next attribute.
        raw: the attribute payload without its 4-byte header.
    """

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the precompiled Struct for scalar type name attr_type.

        A falsy byte_order selects native order, "big-endian" selects
        big-endian, and any other non-empty value selects little-endian.
        """
        formats = cls.type_formats[attr_type]
        if byte_order:
            return formats.big if byte_order == "big-endian" else formats.little
        return formats.native

    @classmethod
    def formatted_string(cls, raw, display_hint):
        """Render a decoded value according to a spec display-hint.

        raw is normally bytes; integers (scalar struct members) are also
        accepted for the 'hex' hint. Unknown hints return raw unchanged.
        """
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            # Scalar struct members arrive here as ints, which bytes.hex()
            # rejects with a TypeError.
            if isinstance(raw, int):
                formatted = hex(raw)
            else:
                formatted = bytes.hex(raw, ' ')
        elif display_hint in ['ipv4', 'ipv6']:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the whole payload as one fixed-width scalar of attr_type."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a variable-width scalar from a 4- or 8-byte payload.

        Only the first letter of attr_type ('u' or 's') is used; the width
        comes from the payload length.
        """
        if len(self.raw) not in (4, 8):
            raise Exception(f"Auto-scalar len payload must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode an ASCII payload, dropping its trailing (NUL) byte."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes unchanged."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as a packed array of native-order scalars of `type`."""
        fmt = self.get_format(type)
        return [x[0] for x in fmt.iter_unpack(self.raw)]

    def as_struct(self, members):
        """Decode the payload as a C struct described by spec `members`.

        Returns a dict keyed by member name. 'pad' members only advance the
        offset, 'binary' members yield bytes, and scalar members are
        unpacked in their byte order. A member display-hint is applied
        when present.
        """
        value = dict()
        offset = 0
        for m in members:
            # TODO: handle non-scalar members
            if m.type == 'pad':
                # Padding carries no data. Skip it so it neither appears in
                # the result nor reuses the previous member's value.
                offset += m.len
                continue
            if m.type == 'binary':
                decoded = self.raw[offset : offset + m['len']]
                offset += m['len']
            elif m.type in NlAttr.type_formats:
                fmt = self.get_format(m.type, m.byte_order)
                [decoded] = fmt.unpack_from(self.raw, offset)
                offset += fmt.size
            else:
                raise Exception(f"Unsupported struct member type '{m.type}' for '{m.name}'")
            if m.display_hint:
                decoded = self.formatted_string(decoded, m.display_hint)
            value[m.name] = decoded
        return value

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """The sequence of netlink attributes found in msg starting at offset."""

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            # nla_len counts the 4-byte header, so a smaller value is corrupt.
            # It would also stop the offset from advancing and loop forever.
            if attr.payload_len < 4:
                raise Exception(f"Malformed netlink attribute at offset {offset}: length {attr.payload_len}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """One netlink message (16-byte nlmsghdr + payload) from a buffer.

    For NLMSG_ERROR and NLMSG_DONE messages, the error code is stored in
    self.error and self.done is set. When the kernel set NLM_F_ACK_TLVS,
    the extended-ACK attributes are decoded into the self.extack dict. If
    attr_space is given, a top-level missing-attribute type in the extack
    is replaced by its spec name (and doc, if any).
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # The errno is followed by the echoed request header (16 bytes).
            # NETLINK_CAP_ACK keeps the request payload out of the ACK.
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                else:
                    if 'unknown' not in self.extack:
                        self.extack['unknown'] = []
                    self.extack['unknown'].append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        desc = spec['name']
                        if 'doc' in spec:
                            desc += f" ({spec['doc']})"
                        self.extack['miss-type'] = desc

    def cmd(self):
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
        if self.error:
            msg += '\terror: ' + str(self.error)
        if self.extack:
            msg += '\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages contained in one receive buffer."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            # nlmsg_len counts the 16-byte header, so a smaller value is
            # corrupt and would otherwise loop forever.
            if msg.nl_len < 16:
                raise Exception(f"Malformed netlink message at offset {offset}: length {msg.nl_len}")
            offset += msg.nl_len
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Generic netlink families known to the kernel, keyed by family name.
# Populated on first use by _genl_load_families().
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build an nlmsghdr (without its leading length field) plus a genlmsghdr."""
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the 4-byte total length field (which counts itself) to msg."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump all generic netlink families from nlctrl into genl_family_name_to_id.

    Each entry is a dict with 'id', 'name', optionally 'maxattr', and a
    'mcast' dict mapping multicast group names to ids when the family
    reports groups. On a netlink error the message is printed and loading
    stops early.
    """
    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        global genl_family_name_to_id
        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in NlAttrs(gm.raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """Generic netlink view of an NlMsg: splits off the 4-byte genlmsghdr."""

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is attached only by NetlinkProtocol.decode(). Messages
        # built directly (e.g. while loading families) lack it.
        for a in getattr(self, 'raw_attrs', []):
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Message framing for a classic (non-generic) netlink protocol."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Build an nlmsghdr without its leading length field."""
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        """Start a request; for raw netlink the command is the message type."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg):
        """Wrap nl_msg for this protocol and attach its attributes as raw_attrs.

        When ynl is given, the op's fixed header is skipped before the
        attributes are parsed.
        """
        msg = self._decode(nl_msg)
        fixed_header_size = 0
        if ynl:
            op = ynl.rsp_by_value[msg.cmd()]
            fixed_header_size = ynl._fixed_header_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Look up a multicast group id in the spec's mcast_groups."""
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value


class GenlProtocol(NetlinkProtocol):
    """Generic netlink framing; resolves the family id via nlctrl."""

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        # Raises KeyError when the kernel does not know the family.
        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']

    def message(self, flags, command, version, seq=None):
        """Start a request: nlmsghdr (typed with the family id) + genlmsghdr."""
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Look up a multicast group id as reported by the kernel (not the spec)."""
        if mcast_name not in self.genl_family['mcast']:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return self.genl_family['mcast'][mcast_name]


#
# YNL implementation details.
#
410# 411 412 413class YnlFamily(SpecFamily): 414 def __init__(self, def_path, schema=None, process_unknown=False): 415 super().__init__(def_path, schema) 416 417 self.include_raw = False 418 self.process_unknown = process_unknown 419 420 try: 421 if self.proto == "netlink-raw": 422 self.nlproto = NetlinkProtocol(self.yaml['name'], 423 self.yaml['protonum']) 424 else: 425 self.nlproto = GenlProtocol(self.yaml['name']) 426 except KeyError: 427 raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel") 428 429 self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num) 430 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) 431 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1) 432 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1) 433 434 self.async_msg_ids = set() 435 self.async_msg_queue = [] 436 437 for msg in self.msgs.values(): 438 if msg.is_async: 439 self.async_msg_ids.add(msg.rsp_value) 440 441 for op_name, op in self.ops.items(): 442 bound_f = functools.partial(self._op, op_name) 443 setattr(self, op.ident_name, bound_f) 444 445 446 def ntf_subscribe(self, mcast_name): 447 mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups) 448 self.sock.bind((0, 0)) 449 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP, 450 mcast_id) 451 452 def _add_attr(self, space, name, value): 453 try: 454 attr = self.attr_sets[space][name] 455 except KeyError: 456 raise Exception(f"Space '{space}' has no attribute '{name}'") 457 nl_type = attr.value 458 if attr["type"] == 'nest': 459 nl_type |= Netlink.NLA_F_NESTED 460 attr_payload = b'' 461 for subname, subvalue in value.items(): 462 attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue) 463 elif attr["type"] == 'flag': 464 attr_payload = b'' 465 elif attr["type"] == 'string': 466 attr_payload = str(value).encode('ascii') + b'\x00' 467 elif attr["type"] == 'binary': 
468 if isinstance(value, bytes): 469 attr_payload = value 470 elif isinstance(value, str): 471 attr_payload = bytes.fromhex(value) 472 else: 473 raise Exception(f'Unknown type for binary attribute, value: {value}') 474 elif attr.is_auto_scalar: 475 scalar = int(value) 476 real_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64') 477 format = NlAttr.get_format(real_type, attr.byte_order) 478 attr_payload = format.pack(int(value)) 479 elif attr['type'] in NlAttr.type_formats: 480 format = NlAttr.get_format(attr['type'], attr.byte_order) 481 attr_payload = format.pack(int(value)) 482 elif attr['type'] in "bitfield32": 483 attr_payload = struct.pack("II", int(value["value"]), int(value["selector"])) 484 else: 485 raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}') 486 487 pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4) 488 return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad 489 490 def _decode_enum(self, raw, attr_spec): 491 enum = self.consts[attr_spec['enum']] 492 if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): 493 i = 0 494 value = set() 495 while raw: 496 if raw & 1: 497 value.add(enum.entries_by_val[i].name) 498 raw >>= 1 499 i += 1 500 else: 501 value = enum.entries_by_val[raw].name 502 return value 503 504 def _decode_binary(self, attr, attr_spec): 505 if attr_spec.struct_name: 506 members = self.consts[attr_spec.struct_name] 507 decoded = attr.as_struct(members) 508 for m in members: 509 if m.enum: 510 decoded[m.name] = self._decode_enum(decoded[m.name], m) 511 elif attr_spec.sub_type: 512 decoded = attr.as_c_array(attr_spec.sub_type) 513 else: 514 decoded = attr.as_bin() 515 if attr_spec.display_hint: 516 decoded = NlAttr.formatted_string(decoded, attr_spec.display_hint) 517 return decoded 518 519 def _decode_array_nest(self, attr, attr_spec): 520 decoded = [] 521 offset = 0 522 while offset < len(attr.raw): 523 item = NlAttr(attr.raw, offset) 524 offset += item.full_len 525 
526 subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes']) 527 decoded.append({ item.type: subattrs }) 528 return decoded 529 530 def _decode_unknown(self, attr): 531 if attr.is_nest: 532 return self._decode(NlAttrs(attr.raw), None) 533 else: 534 return attr.as_bin() 535 536 def _rsp_add(self, rsp, name, is_multi, decoded): 537 if is_multi == None: 538 if name in rsp and type(rsp[name]) is not list: 539 rsp[name] = [rsp[name]] 540 is_multi = True 541 else: 542 is_multi = False 543 544 if not is_multi: 545 rsp[name] = decoded 546 elif name in rsp: 547 rsp[name].append(decoded) 548 else: 549 rsp[name] = [decoded] 550 551 def _resolve_selector(self, attr_spec, vals): 552 sub_msg = attr_spec.sub_message 553 if sub_msg not in self.sub_msgs: 554 raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}") 555 sub_msg_spec = self.sub_msgs[sub_msg] 556 557 selector = attr_spec.selector 558 if selector not in vals: 559 raise Exception(f"There is no value for {selector} to resolve '{attr_spec.name}'") 560 value = vals[selector] 561 if value not in sub_msg_spec.formats: 562 raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'") 563 564 spec = sub_msg_spec.formats[value] 565 return spec 566 567 def _decode_sub_msg(self, attr, attr_spec, rsp): 568 msg_format = self._resolve_selector(attr_spec, rsp) 569 decoded = {} 570 offset = 0 571 if msg_format.fixed_header: 572 decoded.update(self._decode_fixed_header(attr, msg_format.fixed_header)); 573 offset = self._fixed_header_size(msg_format.fixed_header) 574 if msg_format.attr_set: 575 if msg_format.attr_set in self.attr_sets: 576 subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set) 577 decoded.update(subdict) 578 else: 579 raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'") 580 return decoded 581 582 def _decode(self, attrs, space): 583 if space: 584 attr_space = self.attr_sets[space] 585 rsp = dict() 586 for attr 
in attrs: 587 try: 588 attr_spec = attr_space.attrs_by_val[attr.type] 589 except (KeyError, UnboundLocalError): 590 if not self.process_unknown: 591 raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'") 592 attr_name = f"UnknownAttr({attr.type})" 593 self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr)) 594 continue 595 596 if attr_spec["type"] == 'nest': 597 subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes']) 598 decoded = subdict 599 elif attr_spec["type"] == 'string': 600 decoded = attr.as_strz() 601 elif attr_spec["type"] == 'binary': 602 decoded = self._decode_binary(attr, attr_spec) 603 elif attr_spec["type"] == 'flag': 604 decoded = True 605 elif attr_spec.is_auto_scalar: 606 decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order) 607 elif attr_spec["type"] in NlAttr.type_formats: 608 decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order) 609 if 'enum' in attr_spec: 610 decoded = self._decode_enum(decoded, attr_spec) 611 elif attr_spec["type"] == 'array-nest': 612 decoded = self._decode_array_nest(attr, attr_spec) 613 elif attr_spec["type"] == 'bitfield32': 614 value, selector = struct.unpack("II", attr.raw) 615 if 'enum' in attr_spec: 616 value = self._decode_enum(value, attr_spec) 617 selector = self._decode_enum(selector, attr_spec) 618 decoded = {"value": value, "selector": selector} 619 elif attr_spec["type"] == 'sub-message': 620 decoded = self._decode_sub_msg(attr, attr_spec, rsp) 621 else: 622 if not self.process_unknown: 623 raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}') 624 decoded = self._decode_unknown(attr) 625 626 self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded) 627 628 return rsp 629 630 def _decode_extack_path(self, attrs, attr_set, offset, target): 631 for attr in attrs: 632 try: 633 attr_spec = attr_set.attrs_by_val[attr.type] 634 except KeyError: 635 raise Exception(f"Space '{attr_set.name}' has no attribute 
with value '{attr.type}'") 636 if offset > target: 637 break 638 if offset == target: 639 return '.' + attr_spec.name 640 641 if offset + attr.full_len <= target: 642 offset += attr.full_len 643 continue 644 if attr_spec['type'] != 'nest': 645 raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack") 646 offset += 4 647 subpath = self._decode_extack_path(NlAttrs(attr.raw), 648 self.attr_sets[attr_spec['nested-attributes']], 649 offset, target) 650 if subpath is None: 651 return None 652 return '.' + attr_spec.name + subpath 653 654 return None 655 656 def _decode_extack(self, request, op, extack): 657 if 'bad-attr-offs' not in extack: 658 return 659 660 msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set)) 661 offset = 20 + self._fixed_header_size(op.fixed_header) 662 path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset, 663 extack['bad-attr-offs']) 664 if path: 665 del extack['bad-attr-offs'] 666 extack['bad-attr'] = path 667 668 def _fixed_header_size(self, name): 669 if name: 670 fixed_header_members = self.consts[name].members 671 size = 0 672 for m in fixed_header_members: 673 if m.type in ['pad', 'binary']: 674 size += m.len 675 else: 676 format = NlAttr.get_format(m.type, m.byte_order) 677 size += format.size 678 return size 679 else: 680 return 0 681 682 def _decode_fixed_header(self, msg, name): 683 fixed_header_members = self.consts[name].members 684 fixed_header_attrs = dict() 685 offset = 0 686 for m in fixed_header_members: 687 value = None 688 if m.type == 'pad': 689 offset += m.len 690 elif m.type == 'binary': 691 value = msg.raw[offset : offset + m.len] 692 offset += m.len 693 else: 694 format = NlAttr.get_format(m.type, m.byte_order) 695 [ value ] = format.unpack_from(msg.raw, offset) 696 offset += format.size 697 if value is not None: 698 if m.enum: 699 value = self._decode_enum(value, m) 700 fixed_header_attrs[m.name] = value 701 return fixed_header_attrs 702 703 def handle_ntf(self, decoded): 704 
msg = dict() 705 if self.include_raw: 706 msg['raw'] = decoded 707 op = self.rsp_by_value[decoded.cmd()] 708 attrs = self._decode(decoded.raw_attrs, op.attr_set.name) 709 if op.fixed_header: 710 attrs.update(self._decode_fixed_header(decoded, op.fixed_header)) 711 712 msg['name'] = op['name'] 713 msg['msg'] = attrs 714 self.async_msg_queue.append(msg) 715 716 def check_ntf(self): 717 while True: 718 try: 719 reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT) 720 except BlockingIOError: 721 return 722 723 nms = NlMsgs(reply) 724 for nl_msg in nms: 725 if nl_msg.error: 726 print("Netlink error in ntf!?", os.strerror(-nl_msg.error)) 727 print(nl_msg) 728 continue 729 if nl_msg.done: 730 print("Netlink done while checking for ntf!?") 731 continue 732 733 decoded = self.nlproto.decode(self, nl_msg) 734 if decoded.cmd() not in self.async_msg_ids: 735 print("Unexpected msg id done while checking for ntf", decoded) 736 continue 737 738 self.handle_ntf(decoded) 739 740 def operation_do_attributes(self, name): 741 """ 742 For a given operation name, find and return a supported 743 set of attributes (as a dict). 
744 """ 745 op = self.find_operation(name) 746 if not op: 747 return None 748 749 return op['do']['request']['attributes'].copy() 750 751 def _op(self, method, vals, flags=None, dump=False): 752 op = self.ops[method] 753 754 nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK 755 for flag in flags or []: 756 nl_flags |= flag 757 if dump: 758 nl_flags |= Netlink.NLM_F_DUMP 759 760 req_seq = random.randint(1024, 65535) 761 msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq) 762 fixed_header_members = [] 763 if op.fixed_header: 764 fixed_header_members = self.consts[op.fixed_header].members 765 for m in fixed_header_members: 766 value = vals.pop(m.name) if m.name in vals else 0 767 if m.type == 'pad': 768 msg += bytearray(m.len) 769 elif m.type == 'binary': 770 msg += bytes.fromhex(value) 771 else: 772 format = NlAttr.get_format(m.type, m.byte_order) 773 msg += format.pack(value) 774 for name, value in vals.items(): 775 msg += self._add_attr(op.attr_set.name, name, value) 776 msg = _genl_msg_finalize(msg) 777 778 self.sock.send(msg, 0) 779 780 done = False 781 rsp = [] 782 while not done: 783 reply = self.sock.recv(128 * 1024) 784 nms = NlMsgs(reply, attr_space=op.attr_set) 785 for nl_msg in nms: 786 if nl_msg.extack: 787 self._decode_extack(msg, op, nl_msg.extack) 788 789 if nl_msg.error: 790 raise NlError(nl_msg) 791 if nl_msg.done: 792 if nl_msg.extack: 793 print("Netlink warning:") 794 print(nl_msg) 795 done = True 796 break 797 798 decoded = self.nlproto.decode(self, nl_msg) 799 800 # Check if this is a reply to our request 801 if nl_msg.nl_seq != req_seq or decoded.cmd() != op.rsp_value: 802 if decoded.cmd() in self.async_msg_ids: 803 self.handle_ntf(decoded) 804 continue 805 else: 806 print('Unexpected message: ' + repr(decoded)) 807 continue 808 809 rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name) 810 if op.fixed_header: 811 rsp_msg.update(self._decode_fixed_header(decoded, op.fixed_header)) 812 rsp.append(rsp_msg) 813 814 if not rsp: 
815 return None 816 if not dump and len(rsp) == 1: 817 return rsp[0] 818 return rsp 819 820 def do(self, method, vals, flags=None): 821 return self._op(method, vals, flags) 822 823 def dump(self, method, vals): 824 return self._op(method, vals, [], dump=True) 825