1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause 2# 3# pylint: disable=missing-class-docstring, missing-function-docstring 4# pylint: disable=too-many-branches, too-many-locals, too-many-instance-attributes 5# pylint: disable=too-many-lines 6 7""" 8YAML Netlink Library 9 10An implementation of the genetlink and raw netlink protocols. 11""" 12 13from collections import namedtuple 14from enum import Enum 15import functools 16import os 17import random 18import socket 19import struct 20from struct import Struct 21import sys 22import ipaddress 23import uuid 24import queue 25import selectors 26import time 27 28from .nlspec import SpecFamily 29 30# 31# Generic Netlink code which should really be in some library, but I can't quickly find one. 32# 33 34 35class YnlException(Exception): 36 pass 37 38 39# pylint: disable=too-few-public-methods 40class Netlink: 41 # Netlink socket 42 SOL_NETLINK = 270 43 44 NETLINK_ADD_MEMBERSHIP = 1 45 NETLINK_CAP_ACK = 10 46 NETLINK_EXT_ACK = 11 47 NETLINK_GET_STRICT_CHK = 12 48 49 # Netlink message 50 NLMSG_ERROR = 2 51 NLMSG_DONE = 3 52 53 NLM_F_REQUEST = 1 54 NLM_F_ACK = 4 55 NLM_F_ROOT = 0x100 56 NLM_F_MATCH = 0x200 57 58 NLM_F_REPLACE = 0x100 59 NLM_F_EXCL = 0x200 60 NLM_F_CREATE = 0x400 61 NLM_F_APPEND = 0x800 62 63 NLM_F_CAPPED = 0x100 64 NLM_F_ACK_TLVS = 0x200 65 66 NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH 67 68 NLA_F_NESTED = 0x8000 69 NLA_F_NET_BYTEORDER = 0x4000 70 71 NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER 72 73 # Genetlink defines 74 NETLINK_GENERIC = 16 75 76 GENL_ID_CTRL = 0x10 77 78 # nlctrl 79 CTRL_CMD_GETFAMILY = 3 80 81 CTRL_ATTR_FAMILY_ID = 1 82 CTRL_ATTR_FAMILY_NAME = 2 83 CTRL_ATTR_MAXATTR = 5 84 CTRL_ATTR_MCAST_GROUPS = 7 85 86 CTRL_ATTR_MCAST_GRP_NAME = 1 87 CTRL_ATTR_MCAST_GRP_ID = 2 88 89 # Extack types 90 NLMSGERR_ATTR_MSG = 1 91 NLMSGERR_ATTR_OFFS = 2 92 NLMSGERR_ATTR_COOKIE = 3 93 NLMSGERR_ATTR_POLICY = 4 94 NLMSGERR_ATTR_MISS_TYPE = 5 95 NLMSGERR_ATTR_MISS_NEST = 6 96 97 # Policy types 98 NL_POLICY_TYPE_ATTR_TYPE = 1 99 NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2 100 NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3 101 NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4 102 NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5 103 NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6 104 NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7 105 NL_POLICY_TYPE_ATTR_POLICY_IDX = 8 106 NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9 107 NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10 108 NL_POLICY_TYPE_ATTR_PAD = 11 109 NL_POLICY_TYPE_ATTR_MASK = 12 110 111 AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64', 112 's8', 's16', 's32', 's64', 113 'binary', 'string', 'nul-string', 114 'nested', 'nested-array', 115 'bitfield32', 'sint', 'uint']) 116 117class NlError(Exception): 118 def __init__(self, nl_msg): 119 self.nl_msg = nl_msg 120 self.error = -nl_msg.error 121 122 def __str__(self): 123 msg = "Netlink error: " 124 125 extack = self.nl_msg.extack.copy() if self.nl_msg.extack else {} 126 if 'msg' in extack: 127 msg += extack['msg'] + ': ' 128 del extack['msg'] 129 msg += os.strerror(self.error) 130 if extack: 131 msg += ' ' + str(extack) 132 return msg 133 134 135class ConfigError(Exception): 136 pass 137 138 139class NlAttr: 140 ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little']) 141 type_formats = { 142 'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")), 143 's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")), 144 'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")), 145 's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")), 146 'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")), 147 's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")), 148 'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")), 149 's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q")) 150 } 151 152 def __init__(self, raw, offset): 153 self._len, self._type = struct.unpack("HH", raw[offset : offset + 4]) 154 self.type = self._type & ~Netlink.NLA_TYPE_MASK 155 self.is_nest = self._type & Netlink.NLA_F_NESTED 156 self.payload_len = self._len 157 self.full_len = (self.payload_len + 3) & ~3 158 self.raw = raw[offset + 4 : offset + self.payload_len] 159 160 @classmethod 161 def get_format(cls, attr_type, byte_order=None): 162 format_ = cls.type_formats[attr_type] 163 if byte_order: 164 return format_.big if byte_order == "big-endian" \ 165 else format_.little 166 return format_.native 167 168 def as_scalar(self, attr_type, byte_order=None): 169 format_ = self.get_format(attr_type, byte_order) 170 return format_.unpack(self.raw)[0] 171 172 def as_auto_scalar(self, attr_type, byte_order=None): 173 if len(self.raw) != 4 and len(self.raw) != 8: 174 raise YnlException(f"Auto-scalar len payload be 4 or 8 bytes, got {len(self.raw)}") 175 real_type = attr_type[0] + str(len(self.raw) * 8) 176 format_ = self.get_format(real_type, byte_order) 177 return format_.unpack(self.raw)[0] 178 179 def as_strz(self): 180 return self.raw.decode('ascii')[:-1] 181 182 def as_bin(self): 183 return self.raw 184 185 def as_c_array(self, c_type): 186 format_ = self.get_format(c_type) 187 return [ x[0] for x in format_.iter_unpack(self.raw) ] 188 189 def __repr__(self): 190 return f"[type:{self.type} len:{self._len}] {self.raw}" 191 192 193class NlAttrs: 194 def __init__(self, msg, offset=0): 195 self.attrs = [] 196 197 while offset < len(msg): 198 attr = NlAttr(msg, offset) 199 offset += attr.full_len 200 self.attrs.append(attr) 201 202 def __iter__(self): 203 yield from self.attrs 204 205 def __repr__(self): 206 msg = '' 207 for a in self.attrs: 208 if msg: 209 msg += '\n' 210 msg += repr(a) 211 return msg 212 213 214class NlMsg: 215 def __init__(self, msg, offset, attr_space=None): 216 self.hdr = msg[offset : offset + 16] 217 218 self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \ 219 struct.unpack("IHHII", self.hdr) 220 221 self.raw = msg[offset + 16 : offset + self.nl_len] 222 223 self.error = 0 224 self.done = 0 225 226 extack_off = None 227 if self.nl_type == Netlink.NLMSG_ERROR: 228 self.error = struct.unpack("i", self.raw[0:4])[0] 229 self.done = 1 230 extack_off = 20 231 elif self.nl_type == Netlink.NLMSG_DONE: 232 self.error = struct.unpack("i", self.raw[0:4])[0] 233 self.done = 1 234 extack_off = 4 235 236 self.extack = None 237 if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off: 238 self.extack = {} 239 extack_attrs = NlAttrs(self.raw[extack_off:]) 240 for extack in extack_attrs: 241 if extack.type == Netlink.NLMSGERR_ATTR_MSG: 242 self.extack['msg'] = extack.as_strz() 243 elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE: 244 self.extack['miss-type'] = extack.as_scalar('u32') 245 elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST: 246 self.extack['miss-nest'] = extack.as_scalar('u32') 247 elif extack.type == Netlink.NLMSGERR_ATTR_OFFS: 248 self.extack['bad-attr-offs'] = extack.as_scalar('u32') 249 elif extack.type == Netlink.NLMSGERR_ATTR_POLICY: 250 self.extack['policy'] = self._decode_policy(extack.raw) 251 else: 252 if 'unknown' not in self.extack: 253 self.extack['unknown'] = [] 254 self.extack['unknown'].append(extack) 255 256 if attr_space: 257 self.annotate_extack(attr_space) 258 259 def _decode_policy(self, raw): 260 policy = {} 261 for attr in NlAttrs(raw): 262 if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE: 263 type_ = attr.as_scalar('u32') 264 policy['type'] = Netlink.AttrType(type_).name 265 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S: 266 policy['min-value'] = attr.as_scalar('s64') 267 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S: 268 policy['max-value'] = attr.as_scalar('s64') 269 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U: 270 policy['min-value'] = attr.as_scalar('u64') 271 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U: 272 policy['max-value'] = attr.as_scalar('u64') 273 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH: 274 policy['min-length'] = attr.as_scalar('u32') 275 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH: 276 policy['max-length'] = attr.as_scalar('u32') 277 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK: 278 policy['bitfield32-mask'] = attr.as_scalar('u32') 279 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK: 280 policy['mask'] = attr.as_scalar('u64') 281 return policy 282 283 def annotate_extack(self, attr_space): 284 """ Make extack more human friendly with attribute information """ 285 286 # We don't have the ability to parse nests yet, so only do global 287 if 'miss-type' in self.extack and 'miss-nest' not in self.extack: 288 miss_type = self.extack['miss-type'] 289 if miss_type in attr_space.attrs_by_val: 290 spec = attr_space.attrs_by_val[miss_type] 291 self.extack['miss-type'] = spec['name'] 292 if 'doc' in spec: 293 self.extack['miss-type-doc'] = spec['doc'] 294 295 def cmd(self): 296 return self.nl_type 297 298 def __repr__(self): 299 msg = (f"nl_len = {self.nl_len} ({len(self.raw)}) " 300 f"nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}") 301 if self.error: 302 msg += '\n\terror: ' + str(self.error) 303 if self.extack: 304 msg += '\n\textack: ' + repr(self.extack) 305 return msg 306 307 308# pylint: disable=too-few-public-methods 309class NlMsgs: 310 def __init__(self, data): 311 self.msgs = [] 312 313 offset = 0 314 while offset < len(data): 315 msg = NlMsg(data, offset) 316 offset += msg.nl_len 317 self.msgs.append(msg) 318 319 def __iter__(self): 320 yield from self.msgs 321 322 323def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None): 324 # we prepend length in _genl_msg_finalize() 325 if seq is None: 326 seq = random.randint(1, 1024) 327 nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0) 328 genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0) 329 return nlmsg + genlmsg 330 331 332def _genl_msg_finalize(msg): 333 return struct.pack("I", len(msg) + 4) + msg 334 335 336# pylint: disable=too-many-nested-blocks 337def _genl_load_families(): 338 genl_family_name_to_id = {} 339 340 with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock: 341 sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) 342 343 msg = _genl_msg(Netlink.GENL_ID_CTRL, 344 Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP, 345 Netlink.CTRL_CMD_GETFAMILY, 1) 346 msg = _genl_msg_finalize(msg) 347 348 sock.send(msg, 0) 349 350 while True: 351 reply = sock.recv(128 * 1024) 352 nms = NlMsgs(reply) 353 for nl_msg in nms: 354 if nl_msg.error: 355 raise YnlException(f"Netlink error: {nl_msg.error}") 356 if nl_msg.done: 357 return genl_family_name_to_id 358 359 gm = GenlMsg(nl_msg) 360 fam = {} 361 for attr in NlAttrs(gm.raw): 362 if attr.type == Netlink.CTRL_ATTR_FAMILY_ID: 363 fam['id'] = attr.as_scalar('u16') 364 elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME: 365 fam['name'] = attr.as_strz() 366 elif attr.type == Netlink.CTRL_ATTR_MAXATTR: 367 fam['maxattr'] = attr.as_scalar('u32') 368 elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS: 369 fam['mcast'] = {} 370 for entry in NlAttrs(attr.raw): 371 mcast_name = None 372 mcast_id = None 373 for entry_attr in NlAttrs(entry.raw): 374 if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME: 375 mcast_name = entry_attr.as_strz() 376 elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID: 377 mcast_id = entry_attr.as_scalar('u32') 378 if mcast_name and mcast_id is not None: 379 fam['mcast'][mcast_name] = mcast_id 380 if 'name' in fam and 'id' in fam: 381 genl_family_name_to_id[fam['name']] = fam 382 383 384class GenlMsg: 385 def __init__(self, nl_msg): 386 self.nl = nl_msg 387 self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0) 388 self.raw = nl_msg.raw[4:] 389 self.raw_attrs = [] 390 391 def cmd(self): 392 return self.genl_cmd 393 394 def __repr__(self): 395 msg = repr(self.nl) 396 msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n" 397 for a in self.raw_attrs: 398 msg += '\t\t' + repr(a) + '\n' 399 return msg 400 401 402class NetlinkProtocol: 403 def __init__(self, family_name, proto_num): 404 self.family_name = family_name 405 self.proto_num = proto_num 406 407 def _message(self, nl_type, nl_flags, seq=None): 408 if seq is None: 409 seq = random.randint(1, 1024) 410 nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0) 411 return nlmsg 412 413 def message(self, flags, command, _version, seq=None): 414 return self._message(command, flags, seq) 415 416 def _decode(self, nl_msg): 417 return nl_msg 418 419 def decode(self, ynl, nl_msg, op): 420 msg = self._decode(nl_msg) 421 if op is None: 422 op = ynl.rsp_by_value[msg.cmd()] 423 fixed_header_size = ynl.struct_size(op.fixed_header) 424 msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size) 425 return msg 426 427 def get_mcast_id(self, mcast_name, mcast_groups): 428 if mcast_name not in mcast_groups: 429 raise YnlException(f'Multicast group "{mcast_name}" not present in the spec') 430 return mcast_groups[mcast_name].value 431 432 def msghdr_size(self): 433 return 16 434 435 436class GenlProtocol(NetlinkProtocol): 437 genl_family_name_to_id = {} 438 439 def __init__(self, family_name): 440 super().__init__(family_name, Netlink.NETLINK_GENERIC) 441 442 if not GenlProtocol.genl_family_name_to_id: 443 GenlProtocol.genl_family_name_to_id = _genl_load_families() 444 445 self.genl_family = GenlProtocol.genl_family_name_to_id[family_name] 446 self.family_id = GenlProtocol.genl_family_name_to_id[family_name]['id'] 447 448 def message(self, flags, command, version, seq=None): 449 nlmsg = self._message(self.family_id, flags, seq) 450 genlmsg = struct.pack("BBH", command, version, 0) 451 return nlmsg + genlmsg 452 453 def _decode(self, nl_msg): 454 return GenlMsg(nl_msg) 455 456 def get_mcast_id(self, mcast_name, mcast_groups): 457 if mcast_name not in self.genl_family['mcast']: 458 raise YnlException(f'Multicast group "{mcast_name}" not present in the family') 459 return self.genl_family['mcast'][mcast_name] 460 461 def msghdr_size(self): 462 return super().msghdr_size() + 4 463 464 465# pylint: disable=too-few-public-methods 466class SpaceAttrs: 467 SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values']) 468 469 def __init__(self, attr_space, attrs, outer = None): 470 outer_scopes = outer.scopes if outer else [] 471 inner_scope = self.SpecValuesPair(attr_space, attrs) 472 self.scopes = [inner_scope] + outer_scopes 473 474 def lookup(self, name): 475 for scope in self.scopes: 476 if name in scope.spec: 477 if name in scope.values: 478 return scope.values[name] 479 spec_name = scope.spec.yaml['name'] 480 raise YnlException( 481 f"No value for '{name}' in attribute space '{spec_name}'") 482 raise YnlException(f"Attribute '{name}' not defined in any attribute-set") 483 484 485# 486# YNL implementation details. 487# 488 489 490class YnlFamily(SpecFamily): 491 def __init__(self, def_path, schema=None, process_unknown=False, 492 recv_size=0): 493 super().__init__(def_path, schema) 494 495 self.include_raw = False 496 self.process_unknown = process_unknown 497 498 try: 499 if self.proto == "netlink-raw": 500 self.nlproto = NetlinkProtocol(self.yaml['name'], 501 self.yaml['protonum']) 502 else: 503 self.nlproto = GenlProtocol(self.yaml['name']) 504 except KeyError as err: 505 raise YnlException(f"Family '{self.yaml['name']}' not supported by the kernel") from err 506 507 self._recv_dbg = False 508 # Note that netlink will use conservative (min) message size for 509 # the first dump recv() on the socket, our setting will only matter 510 # from the second recv() on. 511 self._recv_size = recv_size if recv_size else 131072 512 # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo) 513 # for a message, so smaller receive sizes will lead to truncation. 514 # Note that the min size for other families may be larger than 4k! 515 if self._recv_size < 4000: 516 raise ConfigError() 517 518 self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num) 519 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) 520 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1) 521 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1) 522 523 self.async_msg_ids = set() 524 self.async_msg_queue = queue.Queue() 525 526 for msg in self.msgs.values(): 527 if msg.is_async: 528 self.async_msg_ids.add(msg.rsp_value) 529 530 for op_name, op in self.ops.items(): 531 bound_f = functools.partial(self._op, op_name) 532 setattr(self, op.ident_name, bound_f) 533 534 535 def ntf_subscribe(self, mcast_name): 536 mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups) 537 self.sock.bind((0, 0)) 538 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP, 539 mcast_id) 540 541 def set_recv_dbg(self, enabled): 542 self._recv_dbg = enabled 543 544 def _recv_dbg_print(self, reply, nl_msgs): 545 if not self._recv_dbg: 546 return 547 print("Recv: read", len(reply), "bytes,", 548 len(nl_msgs.msgs), "messages", file=sys.stderr) 549 for nl_msg in nl_msgs: 550 print(" ", nl_msg, file=sys.stderr) 551 552 def _encode_enum(self, attr_spec, value): 553 enum = self.consts[attr_spec['enum']] 554 if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): 555 scalar = 0 556 if isinstance(value, str): 557 value = [value] 558 for single_value in value: 559 scalar += enum.entries[single_value].user_value(as_flags = True) 560 return scalar 561 return enum.entries[value].user_value() 562 563 def _get_scalar(self, attr_spec, value): 564 try: 565 return int(value) 566 except (ValueError, TypeError) as e: 567 if 'enum' in attr_spec: 568 return self._encode_enum(attr_spec, value) 569 if attr_spec.display_hint: 570 return self._from_string(value, attr_spec) 571 raise e 572 573 # pylint: disable=too-many-statements 574 def _add_attr(self, space, name, value, search_attrs): 575 try: 576 attr = self.attr_sets[space][name] 577 except KeyError as err: 578 raise YnlException(f"Space '{space}' has no attribute '{name}'") from err 579 nl_type = attr.value 580 581 if attr.is_multi and isinstance(value, list): 582 attr_payload = b'' 583 for subvalue in value: 584 attr_payload += self._add_attr(space, name, subvalue, search_attrs) 585 return attr_payload 586 587 if attr["type"] == 'nest': 588 nl_type |= Netlink.NLA_F_NESTED 589 sub_space = attr['nested-attributes'] 590 attr_payload = self._add_nest_attrs(value, sub_space, search_attrs) 591 elif attr['type'] == 'indexed-array' and attr['sub-type'] == 'nest': 592 nl_type |= Netlink.NLA_F_NESTED 593 sub_space = attr['nested-attributes'] 594 attr_payload = self._encode_indexed_array(value, sub_space, 595 search_attrs) 596 elif attr["type"] == 'flag': 597 if not value: 598 # If value is absent or false then skip attribute creation. 599 return b'' 600 attr_payload = b'' 601 elif attr["type"] == 'string': 602 attr_payload = str(value).encode('ascii') + b'\x00' 603 elif attr["type"] == 'binary': 604 if value is None: 605 attr_payload = b'' 606 elif isinstance(value, bytes): 607 attr_payload = value 608 elif isinstance(value, str): 609 if attr.display_hint: 610 attr_payload = self._from_string(value, attr) 611 else: 612 attr_payload = bytes.fromhex(value) 613 elif isinstance(value, dict) and attr.struct_name: 614 attr_payload = self._encode_struct(attr.struct_name, value) 615 elif isinstance(value, list) and attr.sub_type in NlAttr.type_formats: 616 format_ = NlAttr.get_format(attr.sub_type) 617 attr_payload = b''.join([format_.pack(x) for x in value]) 618 else: 619 raise YnlException(f'Unknown type for binary attribute, value: {value}') 620 elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar: 621 scalar = self._get_scalar(attr, value) 622 if attr.is_auto_scalar: 623 attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64') 624 else: 625 attr_type = attr["type"] 626 format_ = NlAttr.get_format(attr_type, attr.byte_order) 627 attr_payload = format_.pack(scalar) 628 elif attr['type'] in "bitfield32": 629 scalar_value = self._get_scalar(attr, value["value"]) 630 scalar_selector = self._get_scalar(attr, value["selector"]) 631 attr_payload = struct.pack("II", scalar_value, scalar_selector) 632 elif attr['type'] == 'sub-message': 633 msg_format, _ = self._resolve_selector(attr, search_attrs) 634 attr_payload = b'' 635 if msg_format.fixed_header: 636 attr_payload += self._encode_struct(msg_format.fixed_header, value) 637 if msg_format.attr_set: 638 if msg_format.attr_set in self.attr_sets: 639 nl_type |= Netlink.NLA_F_NESTED 640 sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs) 641 for subname, subvalue in value.items(): 642 attr_payload += self._add_attr(msg_format.attr_set, 643 subname, subvalue, sub_attrs) 644 else: 645 raise YnlException(f"Unknown attribute-set '{msg_format.attr_set}'") 646 else: 647 raise YnlException(f'Unknown type at {space} {name} {value} {attr["type"]}') 648 649 return self._add_attr_raw(nl_type, attr_payload) 650 651 def _add_attr_raw(self, nl_type, attr_payload): 652 pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4) 653 return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad 654 655 def _add_nest_attrs(self, value, sub_space, search_attrs): 656 sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs) 657 attr_payload = b'' 658 for subname, subvalue in value.items(): 659 attr_payload += self._add_attr(sub_space, subname, subvalue, 660 sub_attrs) 661 return attr_payload 662 663 def _encode_indexed_array(self, vals, sub_space, search_attrs): 664 attr_payload = b'' 665 for i, val in enumerate(vals): 666 idx = i | Netlink.NLA_F_NESTED 667 val_payload = self._add_nest_attrs(val, sub_space, search_attrs) 668 attr_payload += self._add_attr_raw(idx, val_payload) 669 return attr_payload 670 671 def _get_enum_or_unknown(self, enum, raw): 672 try: 673 name = enum.entries_by_val[raw].name 674 except KeyError as error: 675 if self.process_unknown: 676 name = f"Unknown({raw})" 677 else: 678 raise error 679 return name 680 681 def _decode_enum(self, raw, attr_spec): 682 enum = self.consts[attr_spec['enum']] 683 if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): 684 i = 0 685 value = set() 686 while raw: 687 if raw & 1: 688 value.add(self._get_enum_or_unknown(enum, i)) 689 raw >>= 1 690 i += 1 691 else: 692 value = self._get_enum_or_unknown(enum, raw) 693 return value 694 695 def _decode_binary(self, attr, attr_spec): 696 if attr_spec.struct_name: 697 decoded = self._decode_struct(attr.raw, attr_spec.struct_name) 698 elif attr_spec.sub_type: 699 decoded = attr.as_c_array(attr_spec.sub_type) 700 if 'enum' in attr_spec: 701 decoded = [ self._decode_enum(x, attr_spec) for x in decoded ] 702 elif attr_spec.display_hint: 703 decoded = [ self._formatted_string(x, attr_spec.display_hint) 704 for x in decoded ] 705 else: 706 decoded = attr.as_bin() 707 if attr_spec.display_hint: 708 decoded = self._formatted_string(decoded, attr_spec.display_hint) 709 return decoded 710 711 def _decode_array_attr(self, attr, attr_spec): 712 decoded = [] 713 offset = 0 714 while offset < len(attr.raw): 715 item = NlAttr(attr.raw, offset) 716 offset += item.full_len 717 718 if attr_spec["sub-type"] == 'nest': 719 subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes']) 720 decoded.append({ item.type: subattrs }) 721 elif attr_spec["sub-type"] == 'binary': 722 subattr = item.as_bin() 723 if attr_spec.display_hint: 724 subattr = self._formatted_string(subattr, attr_spec.display_hint) 725 decoded.append(subattr) 726 elif attr_spec["sub-type"] in NlAttr.type_formats: 727 subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order) 728 if 'enum' in attr_spec: 729 subattr = self._decode_enum(subattr, attr_spec) 730 elif attr_spec.display_hint: 731 subattr = self._formatted_string(subattr, attr_spec.display_hint) 732 decoded.append(subattr) 733 else: 734 raise YnlException(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}') 735 return decoded 736 737 def _decode_nest_type_value(self, attr, attr_spec): 738 decoded = {} 739 value = attr 740 for name in attr_spec['type-value']: 741 value = NlAttr(value.raw, 0) 742 decoded[name] = value.type 743 subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes']) 744 decoded.update(subattrs) 745 return decoded 746 747 def _decode_unknown(self, attr): 748 if attr.is_nest: 749 return self._decode(NlAttrs(attr.raw), None) 750 return attr.as_bin() 751 752 def _rsp_add(self, rsp, name, is_multi, decoded): 753 if is_multi is None: 754 if name in rsp and not isinstance(rsp[name], list): 755 rsp[name] = [rsp[name]] 756 is_multi = True 757 else: 758 is_multi = False 759 760 if not is_multi: 761 rsp[name] = decoded 762 elif name in rsp: 763 rsp[name].append(decoded) 764 else: 765 rsp[name] = [decoded] 766 767 def _resolve_selector(self, attr_spec, search_attrs): 768 sub_msg = attr_spec.sub_message 769 if sub_msg not in self.sub_msgs: 770 raise YnlException(f"No sub-message spec named {sub_msg} for {attr_spec.name}") 771 sub_msg_spec = self.sub_msgs[sub_msg] 772 773 selector = attr_spec.selector 774 value = search_attrs.lookup(selector) 775 if value not in sub_msg_spec.formats: 776 raise YnlException(f"No message format for '{value}' in sub-message spec '{sub_msg}'") 777 778 spec = sub_msg_spec.formats[value] 779 return spec, value 780 781 def _decode_sub_msg(self, attr, attr_spec, search_attrs): 782 msg_format, _ = self._resolve_selector(attr_spec, search_attrs) 783 decoded = {} 784 offset = 0 785 if msg_format.fixed_header: 786 decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header)) 787 offset = self.struct_size(msg_format.fixed_header) 788 if msg_format.attr_set: 789 if msg_format.attr_set in self.attr_sets: 790 subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set) 791 decoded.update(subdict) 792 else: 793 raise YnlException(f"Unknown attribute-set '{msg_format.attr_set}' " 794 f"when decoding '{attr_spec.name}'") 795 return decoded 796 797 # pylint: disable=too-many-statements 798 def _decode(self, attrs, space, outer_attrs = None): 799 rsp = {} 800 search_attrs = {} 801 if space: 802 attr_space = self.attr_sets[space] 803 search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs) 804 805 for attr in attrs: 806 try: 807 attr_spec = attr_space.attrs_by_val[attr.type] 808 except (KeyError, UnboundLocalError) as err: 809 if not self.process_unknown: 810 raise YnlException(f"Space '{space}' has no attribute " 811 f"with value '{attr.type}'") from err 812 attr_name = f"UnknownAttr({attr.type})" 813 self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr)) 814 continue 815 816 try: 817 if attr_spec["type"] == 'nest': 818 subdict = self._decode(NlAttrs(attr.raw), 819 attr_spec['nested-attributes'], 820 search_attrs) 821 decoded = subdict 822 elif attr_spec["type"] == 'string': 823 decoded = attr.as_strz() 824 elif attr_spec["type"] == 'binary': 825 decoded = self._decode_binary(attr, attr_spec) 826 elif attr_spec["type"] == 'flag': 827 decoded = True 828 elif attr_spec.is_auto_scalar: 829 decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order) 830 if 'enum' in attr_spec: 831 decoded = self._decode_enum(decoded, attr_spec) 832 elif attr_spec["type"] in NlAttr.type_formats: 833 decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order) 834 if 'enum' in attr_spec: 835 decoded = self._decode_enum(decoded, attr_spec) 836 elif attr_spec.display_hint: 837 decoded = self._formatted_string(decoded, attr_spec.display_hint) 838 elif attr_spec["type"] == 'indexed-array': 839 decoded = self._decode_array_attr(attr, attr_spec) 840 elif attr_spec["type"] == 'bitfield32': 841 value, selector = struct.unpack("II", attr.raw) 842 if 'enum' in attr_spec: 843 value = self._decode_enum(value, attr_spec) 844 selector = self._decode_enum(selector, attr_spec) 845 decoded = {"value": value, "selector": selector} 846 elif attr_spec["type"] == 'sub-message': 847 decoded = self._decode_sub_msg(attr, attr_spec, search_attrs) 848 elif attr_spec["type"] == 'nest-type-value': 849 decoded = self._decode_nest_type_value(attr, attr_spec) 850 else: 851 if not self.process_unknown: 852 raise YnlException(f'Unknown {attr_spec["type"]} ' 853 f'with name {attr_spec["name"]}') 854 decoded = self._decode_unknown(attr) 855 856 self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded) 857 except: 858 print(f"Error decoding '{attr_spec.name}' from '{space}'") 859 raise 860 861 return rsp 862 863 # pylint: disable=too-many-arguments, too-many-positional-arguments 864 def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs): 865 for attr in attrs: 866 try: 867 attr_spec = attr_set.attrs_by_val[attr.type] 868 except KeyError as err: 869 raise YnlException( 870 f"Space '{attr_set.name}' has no attribute with value '{attr.type}'") from err 871 if offset > target: 872 break 873 if offset == target: 874 return '.' + attr_spec.name 875 876 if offset + attr.full_len <= target: 877 offset += attr.full_len 878 continue 879 880 pathname = attr_spec.name 881 if attr_spec['type'] == 'nest': 882 sub_attrs = self.attr_sets[attr_spec['nested-attributes']] 883 search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name'])) 884 elif attr_spec['type'] == 'sub-message': 885 msg_format, value = self._resolve_selector(attr_spec, search_attrs) 886 if msg_format is None: 887 raise YnlException(f"Can't resolve sub-message of " 888 f"{attr_spec['name']} for extack") 889 sub_attrs = self.attr_sets[msg_format.attr_set] 890 pathname += f"({value})" 891 else: 892 raise YnlException(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack") 893 offset += 4 894 subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs, 895 offset, target, search_attrs) 896 if subpath is None: 897 return None 898 return '.' + pathname + subpath 899 900 return None 901 902 def _decode_extack(self, request, op, extack, vals): 903 if 'bad-attr-offs' not in extack: 904 return 905 906 msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op) 907 offset = self.nlproto.msghdr_size() + self.struct_size(op.fixed_header) 908 search_attrs = SpaceAttrs(op.attr_set, vals) 909 path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset, 910 extack['bad-attr-offs'], search_attrs) 911 if path: 912 del extack['bad-attr-offs'] 913 extack['bad-attr'] = path 914 915 def struct_size(self, name): 916 if name: 917 members = self.consts[name].members 918 size = 0 919 for m in members: 920 if m.type in ['pad', 'binary']: 921 if m.struct: 922 size += self.struct_size(m.struct) 923 else: 924 size += m.len 925 else: 926 format_ = NlAttr.get_format(m.type, m.byte_order) 927 size += format_.size 928 return size 929 return 0 930 931 def _decode_struct(self, data, name): 932 members = self.consts[name].members 933 attrs = {} 934 offset = 0 935 for m in members: 936 value = None 937 if m.type == 'pad': 938 offset += m.len 939 elif m.type == 'binary': 940 if m.struct: 941 len_ = self.struct_size(m.struct) 942 value = self._decode_struct(data[offset : offset + len_], 943 m.struct) 944 offset += len_ 945 else: 946 value = data[offset : offset + m.len] 947 offset += m.len 948 else: 949 format_ = NlAttr.get_format(m.type, m.byte_order) 950 [ value ] = format_.unpack_from(data, offset) 951 offset += format_.size 952 if value is not None: 953 if m.enum: 954 value = self._decode_enum(value, m) 955 elif m.display_hint: 956 value = self._formatted_string(value, m.display_hint) 957 attrs[m.name] = value 958 return attrs 959 960 def _encode_struct(self, name, vals): 961 members = self.consts[name].members 962 attr_payload = b'' 963 for m in members: 964 value = vals.pop(m.name) if m.name in vals else None 965 if m.type == 'pad': 966 attr_payload += bytearray(m.len) 967 elif m.type == 'binary': 968 if m.struct: 969 if value is None: 970 value = {} 971 attr_payload += self._encode_struct(m.struct, value) 972 else: 973 if value is None: 974 attr_payload += bytearray(m.len) 975 else: 976 attr_payload += bytes.fromhex(value) 977 else: 978 if value is None: 979 value = 0 980 format_ = NlAttr.get_format(m.type, m.byte_order) 981 attr_payload += format_.pack(value) 982 return attr_payload 983 984 def _formatted_string(self, raw, display_hint): 985 if display_hint == 'mac': 986 formatted = ':'.join(f'{b:02x}' for b in raw) 987 elif display_hint == 'hex': 988 if isinstance(raw, int): 989 formatted = hex(raw) 990 else: 991 formatted = bytes.hex(raw, ' ') 992 elif display_hint in [ 'ipv4', 'ipv6', 'ipv4-or-v6' ]: 993 formatted = format(ipaddress.ip_address(raw)) 994 elif display_hint == 'uuid': 995 formatted = str(uuid.UUID(bytes=raw)) 996 else: 997 formatted = raw 998 return formatted 999 1000 def _from_string(self, string, attr_spec): 1001 if attr_spec.display_hint in ['ipv4', 'ipv6', 'ipv4-or-v6']: 1002 ip = ipaddress.ip_address(string) 1003 if attr_spec['type'] == 'binary': 1004 raw = ip.packed 1005 else: 1006 raw = int(ip) 1007 elif attr_spec.display_hint == 'hex': 1008 if attr_spec['type'] == 'binary': 1009 raw = bytes.fromhex(string) 1010 else: 1011 raw = int(string, 16) 1012 elif attr_spec.display_hint == 'mac': 1013 # Parse MAC address in format "00:11:22:33:44:55" or "001122334455" 1014 if ':' in string: 1015 mac_bytes = [int(x, 16) for x in string.split(':')] 1016 else: 1017 if len(string) % 2 != 0: 1018 raise YnlException(f"Invalid MAC address format: {string}") 1019 mac_bytes = [int(string[i:i+2], 16) for i in range(0, len(string), 2)] 1020 raw = bytes(mac_bytes) 1021 else: 1022 raise YnlException(f"Display hint '{attr_spec.display_hint}' not implemented" 1023 f" when parsing '{attr_spec['name']}'") 1024 return raw 1025 1026 def handle_ntf(self, decoded): 1027 msg = {} 1028 if self.include_raw: 1029 msg['raw'] = decoded 1030 op = self.rsp_by_value[decoded.cmd()] 1031 attrs = self._decode(decoded.raw_attrs, op.attr_set.name) 1032 if op.fixed_header: 1033 attrs.update(self._decode_struct(decoded.raw, op.fixed_header)) 1034 1035 msg['name'] = op['name'] 1036 msg['msg'] = attrs 1037 self.async_msg_queue.put(msg) 1038 1039 def check_ntf(self): 1040 while True: 1041 try: 1042 reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT) 1043 except BlockingIOError: 1044 return 1045 1046 nms = NlMsgs(reply) 1047 self._recv_dbg_print(reply, nms) 1048 for nl_msg in nms: 1049 if nl_msg.error: 1050 print("Netlink error in ntf!?", os.strerror(-nl_msg.error)) 1051 print(nl_msg) 1052 continue 1053 if nl_msg.done: 1054 print("Netlink done while checking for ntf!?") 1055 continue 1056 1057 decoded = self.nlproto.decode(self, nl_msg, None) 1058 if decoded.cmd() not in self.async_msg_ids: 1059 print("Unexpected msg id while checking for ntf", decoded) 1060 continue 1061 1062 self.handle_ntf(decoded) 1063 1064 def poll_ntf(self, duration=None): 1065 start_time = time.time() 1066 selector = selectors.DefaultSelector() 1067 selector.register(self.sock, selectors.EVENT_READ) 1068 1069 while True: 1070 try: 1071 yield self.async_msg_queue.get_nowait() 1072 except queue.Empty: 1073 if duration is not None: 1074 timeout = start_time + duration - time.time() 1075 if timeout <= 0: 1076 return 1077 else: 1078 timeout = None 1079 events = selector.select(timeout) 1080 if events: 1081 self.check_ntf() 1082 1083 def operation_do_attributes(self, name): 1084 """ 1085 For a given operation name, find and return a supported 1086 set of attributes (as a dict). 1087 """ 1088 op = self.find_operation(name) 1089 if not op: 1090 return None 1091 1092 return op['do']['request']['attributes'].copy() 1093 1094 def _encode_message(self, op, vals, flags, req_seq): 1095 nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK 1096 for flag in flags or []: 1097 nl_flags |= flag 1098 1099 msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq) 1100 if op.fixed_header: 1101 msg += self._encode_struct(op.fixed_header, vals) 1102 search_attrs = SpaceAttrs(op.attr_set, vals) 1103 for name, value in vals.items(): 1104 msg += self._add_attr(op.attr_set.name, name, value, search_attrs) 1105 msg = _genl_msg_finalize(msg) 1106 return msg 1107 1108 # pylint: disable=too-many-statements 1109 def _ops(self, ops): 1110 reqs_by_seq = {} 1111 req_seq = random.randint(1024, 65535) 1112 payload = b'' 1113 for (method, vals, flags) in ops: 1114 op = self.ops[method] 1115 msg = self._encode_message(op, vals, flags, req_seq) 1116 reqs_by_seq[req_seq] = (op, vals, msg, flags) 1117 payload += msg 1118 req_seq += 1 1119 1120 self.sock.send(payload, 0) 1121 1122 done = False 1123 rsp = [] 1124 op_rsp = [] 1125 while not done: 1126 reply = self.sock.recv(self._recv_size) 1127 nms = NlMsgs(reply) 1128 self._recv_dbg_print(reply, nms) 1129 for nl_msg in nms: 1130 if nl_msg.nl_seq in reqs_by_seq: 1131 (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq] 1132 if nl_msg.extack: 1133 nl_msg.annotate_extack(op.attr_set) 1134 self._decode_extack(req_msg, op, nl_msg.extack, vals) 1135 else: 1136 op = None 1137 req_flags = [] 1138 1139 if nl_msg.error: 1140 raise NlError(nl_msg) 1141 if nl_msg.done: 1142 if nl_msg.extack: 1143 print("Netlink warning:") 1144 print(nl_msg) 1145 1146 if Netlink.NLM_F_DUMP in req_flags: 1147 rsp.append(op_rsp) 1148 elif not op_rsp: 1149 rsp.append(None) 1150 elif len(op_rsp) == 1: 1151 rsp.append(op_rsp[0]) 1152 else: 1153 rsp.append(op_rsp) 1154 op_rsp = [] 1155 1156 del reqs_by_seq[nl_msg.nl_seq] 1157 done = len(reqs_by_seq) == 0 1158 break 1159 1160 decoded = self.nlproto.decode(self, nl_msg, op) 1161 1162 # Check if this is a reply to our request 1163 if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value: 1164 if decoded.cmd() in self.async_msg_ids: 1165 self.handle_ntf(decoded) 1166 continue 1167 print('Unexpected message: ' + repr(decoded)) 1168 continue 1169 1170 rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name) 1171 if op.fixed_header: 1172 rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header)) 1173 op_rsp.append(rsp_msg) 1174 1175 return rsp 1176 1177 def _op(self, method, vals, flags=None, dump=False): 1178 req_flags = flags or [] 1179 if dump: 1180 req_flags.append(Netlink.NLM_F_DUMP) 1181 1182 ops = [(method, vals, req_flags)] 1183 return self._ops(ops)[0] 1184 1185 def do(self, method, vals, flags=None): 1186 return self._op(method, vals, flags) 1187 1188 def dump(self, method, vals): 1189 return self._op(method, vals, dump=True) 1190 1191 def do_multi(self, ops): 1192 return self._ops(ops) 1193