#!/usr/local/bin/python3
"""Helpers for building, sending, receiving and decoding routing socket
(AF_ROUTE) messages.

Kernel structures are mirrored with ``ctypes`` so that a message can be
encoded with ``bytes(msg)`` and decoded with ``RtsockRtMessage.from_bytes()``.
"""
import os
import socket
import struct
import sys
from ctypes import c_byte
from ctypes import c_char
from ctypes import c_int
from ctypes import c_long
from ctypes import c_uint32
from ctypes import c_ulong
from ctypes import c_ushort
from ctypes import sizeof
from ctypes import Structure
from typing import Dict
from typing import List
from typing import Optional
from typing import Union


def roundup2(val: int, num: int) -> int:
    """Round ``val`` up to the nearest multiple of ``num``.

    ``num`` must be a power of two; the bit trick below relies on it.
    """
    if val % num:
        return (val | (num - 1)) + 1
    else:
        return val


class RtSockException(OSError):
    """Error raised when sockaddr/routing message data is malformed."""

    pass


class RtConst:
    """Routing socket constants plus helpers that render values by name."""

    RTM_VERSION = 5
    # Sockaddrs inside a routing message are padded to sizeof(long).
    ALIGN = sizeof(c_long)

    AF_INET = socket.AF_INET
    AF_INET6 = socket.AF_INET6
    # socket.AF_LINK is only exposed on BSD-derived builds of Python; fall back
    # to the BSD value (18) so the encoding/decoding helpers still import
    # on other hosts.
    AF_LINK = getattr(socket, "AF_LINK", 18)

    RTA_DST = 0x1
    RTA_GATEWAY = 0x2
    RTA_NETMASK = 0x4
    RTA_GENMASK = 0x8
    RTA_IFP = 0x10
    RTA_IFA = 0x20
    RTA_AUTHOR = 0x40
    RTA_BRD = 0x80

    RTM_ADD = 1
    RTM_DELETE = 2
    RTM_CHANGE = 3
    RTM_GET = 4

    RTF_UP = 0x1
    RTF_GATEWAY = 0x2
    RTF_HOST = 0x4
    RTF_REJECT = 0x8
    RTF_DYNAMIC = 0x10
    RTF_MODIFIED = 0x20
    RTF_DONE = 0x40
    RTF_XRESOLVE = 0x200
    RTF_LLINFO = 0x400
    RTF_LLDATA = 0x400
    RTF_STATIC = 0x800
    RTF_BLACKHOLE = 0x1000
    RTF_PROTO2 = 0x4000
    RTF_PROTO1 = 0x8000
    RTF_PROTO3 = 0x40000
    RTF_FIXEDMTU = 0x80000
    RTF_PINNED = 0x100000
    RTF_LOCAL = 0x200000
    RTF_BROADCAST = 0x400000
    RTF_MULTICAST = 0x800000
    RTF_STICKY = 0x10000000
    RTF_RNH_LOCKED = 0x40000000
    RTF_GWFLAG_COMPAT = 0x80000000

    RTV_MTU = 0x1
    RTV_HOPCOUNT = 0x2
    RTV_EXPIRE = 0x4
    RTV_RPIPE = 0x8
    RTV_SPIPE = 0x10
    RTV_SSTHRESH = 0x20
    RTV_RTT = 0x40
    RTV_RTTVAR = 0x80
    RTV_WEIGHT = 0x100

    @staticmethod
    def get_props(prefix: str) -> List[str]:
        """Return the (sorted) attribute names of RtConst starting with ``prefix``."""
        return [n for n in dir(RtConst) if n.startswith(prefix)]

    @staticmethod
    def get_name(prefix: str, value: int) -> str:
        """Return the first constant name with ``prefix`` equal to ``value``.

        Falls back to ``"U:<prefix>:<value>"`` when nothing matches.
        """
        props = RtConst.get_props(prefix)
        for prop in props:
            if getattr(RtConst, prop) == value:
                return prop
        return "U:{}:{}".format(prefix, value)

    @staticmethod
    def get_bitmask_map(prefix: str, value: int) -> Dict[int, str]:
        """Split ``value`` into its set bits, mapping each bit to a name.

        Bits without a matching ``prefix`` constant are rendered with ``hex()``.
        The result is ordered from the lowest bit to the highest.
        """
        props = RtConst.get_props(prefix)
        propmap = {getattr(RtConst, prop): prop for prop in props}
        if value < 0:
            # Flags arrive in signed 32-bit C ints (e.g. rtm_flags); bit 31 makes
            # them negative and a negative value never reaches 0 in the loop.
            value &= 0xFFFFFFFF
        v = 1
        ret = {}
        while value:
            if v & value:
                ret[v] = propmap.get(v, hex(v))
                value -= v
            v *= 2
        return ret

    @staticmethod
    def get_bitmask_str(prefix: str, value: int) -> str:
        """Return the comma-separated names of the bits set in ``value``."""
        bmap = RtConst.get_bitmask_map(prefix, value)
        return ",".join(bmap.values())


def _sa_size(sa_len: int) -> int:
    """Return the space a sockaddr of ``sa_len`` bytes occupies in a message.

    Sockaddrs are padded to ``RtConst.ALIGN``; a zero-length sockaddr still
    takes one alignment unit (the kernel's SA_SIZE() convention), so parsing
    never stalls on it.
    """
    if sa_len == 0:
        return RtConst.ALIGN
    return roundup2(sa_len, RtConst.ALIGN)


class RtMetrics(Structure):
    """ctypes mirror of the route metrics block embedded in rt_msghdr."""

    _fields_ = [
        ("rmx_locks", c_ulong),
        ("rmx_mtu", c_ulong),
        ("rmx_hopcount", c_ulong),
        ("rmx_expire", c_ulong),
        ("rmx_recvpipe", c_ulong),
        ("rmx_sendpipe", c_ulong),
        ("rmx_ssthresh", c_ulong),
        ("rmx_rtt", c_ulong),
        ("rmx_rttvar", c_ulong),
        ("rmx_pksent", c_ulong),
        ("rmx_weight", c_ulong),
        ("rmx_nhidx", c_ulong),
        ("rmx_filler", c_ulong * 2),
    ]


class RtMsgHdr(Structure):
    """ctypes mirror of the routing message header (rt_msghdr).

    rtm_type sits at byte offset 3 (after rtm_msglen and rtm_version).
    """

    _fields_ = [
        ("rtm_msglen", c_ushort),
        ("rtm_version", c_byte),
        ("rtm_type", c_byte),
        ("rtm_index", c_ushort),
        ("_rtm_spare1", c_ushort),
        ("rtm_flags", c_int),
        ("rtm_addrs", c_int),
        ("rtm_pid", c_int),
        ("rtm_seq", c_int),
        ("rtm_errno", c_int),
        ("rtm_fmask", c_int),
        ("rtm_inits", c_ulong),
        ("rtm_rmx", RtMetrics),
    ]


class SockaddrIn(Structure):
    """ctypes mirror of sockaddr_in (16 bytes, address at offset 4)."""

    _fields_ = [
        ("sin_len", c_byte),
        ("sin_family", c_byte),
        ("sin_port", c_ushort),
        ("sin_addr", c_uint32),
        ("sin_zero", c_char * 8),
    ]


class SockaddrIn6(Structure):
    """ctypes mirror of sockaddr_in6 (28 bytes: address at 8, scope id at 24)."""

    _fields_ = [
        ("sin6_len", c_byte),
        ("sin6_family", c_byte),
        ("sin6_port", c_ushort),
        ("sin6_flowinfo", c_uint32),
        ("sin6_addr", c_byte * 16),
        ("sin6_scope_id", c_uint32),
    ]


class SockaddrDl(Structure):
    """ctypes mirror of the fixed part of sockaddr_dl.

    NOTE(review): only 8 bytes of sdl_data are mirrored; the kernel structure
    may carry more, so sizeof(SockaddrDl) is a lower bound -- confirm before
    relying on it.
    """

    _fields_ = [
        ("sdl_len", c_byte),
        ("sdl_family", c_byte),
        ("sdl_index", c_ushort),
        ("sdl_type", c_byte),
        ("sdl_nlen", c_byte),
        ("sdl_alen", c_byte),
        ("sdl_slen", c_byte),
        ("sdl_data", c_byte * 8),
    ]


class SaHelper(object):
    """Build and pretty-print raw sockaddr byte strings."""

    @staticmethod
    def is_ipv6(ip: str) -> bool:
        """Return True if ``ip`` looks like an IPv6 literal (contains ':')."""
        return ":" in ip

    @staticmethod
    def ip_sa(ip: str, scopeid: int = 0) -> bytes:
        """Return a sockaddr_in/sockaddr_in6 for ``ip``; ``scopeid`` is IPv6-only."""
        if SaHelper.is_ipv6(ip):
            return SaHelper.ip6_sa(ip, scopeid)
        else:
            return SaHelper.ip4_sa(ip)

    @staticmethod
    def ip4_sa(ip: str) -> bytes:
        """Return the bytes of a sockaddr_in holding ``ip`` with port 0."""
        # sin_addr is a native-order c_uint32: decoding the network-order bytes
        # with the host byte order makes bytes(sin) reproduce them unchanged.
        addr_int = int.from_bytes(
            socket.inet_pton(socket.AF_INET, ip), sys.byteorder
        )
        sin = SockaddrIn(sizeof(SockaddrIn), socket.AF_INET, 0, addr_int)
        return bytes(sin)

    @staticmethod
    def ip6_sa(ip6: str, scopeid: int) -> bytes:
        """Return the bytes of a sockaddr_in6 for ``ip6`` with ``scopeid``."""
        addr_bytes = (c_byte * 16).from_buffer_copy(
            socket.inet_pton(socket.AF_INET6, ip6)
        )
        sin6 = SockaddrIn6(
            sizeof(SockaddrIn6), socket.AF_INET6, 0, 0, addr_bytes, scopeid
        )
        return bytes(sin6)

    @staticmethod
    def link_sa(ifindex: int = 0, iftype: int = 0) -> bytes:
        """Return a sockaddr_dl carrying only an interface index and type."""
        sa = SockaddrDl(
            sizeof(SockaddrDl), RtConst.AF_LINK, c_ushort(ifindex), iftype
        )
        return bytes(sa)

    @staticmethod
    def pxlen4_sa(pxlen: int) -> bytes:
        """Return an IPv4 netmask sockaddr for prefix length ``pxlen``."""
        return SaHelper.ip_sa(SaHelper.pxlen_to_ip4(pxlen))

    @staticmethod
    def pxlen_to_ip4(pxlen: int) -> str:
        """Return the dotted IPv4 netmask for prefix length ``pxlen`` (0..32)."""
        if pxlen == 32:
            return "255.255.255.255"
        else:
            addr = 0xFFFFFFFF - ((1 << (32 - pxlen)) - 1)
            addr_bytes = struct.pack("!I", addr)
            return socket.inet_ntop(socket.AF_INET, addr_bytes)

    @staticmethod
    def pxlen6_sa(pxlen: int) -> bytes:
        """Return an IPv6 netmask sockaddr for prefix length ``pxlen``."""
        return SaHelper.ip_sa(SaHelper.pxlen_to_ip6(pxlen))

    @staticmethod
    def pxlen_to_ip6(pxlen: int) -> str:
        """Return the IPv6 netmask text for prefix length ``pxlen`` (0..128)."""
        ip6_b = [0] * 16
        start = 0
        while pxlen > 8:
            ip6_b[start] = 0xFF
            pxlen -= 8
            start += 1
        ip6_b[start] = 0xFF - ((1 << (8 - pxlen)) - 1)
        return socket.inet_ntop(socket.AF_INET6, bytes(ip6_b))

    @staticmethod
    def print_sa_inet(sa: bytes):
        """Render an IPv4 sockaddr as its address text.

        Raises RtSockException if ``sa`` is shorter than 8 bytes.
        """
        if len(sa) < 8:
            raise RtSockException("IPv4 sa size too small: {}".format(len(sa)))
        addr = socket.inet_ntop(socket.AF_INET, sa[4:8])
        return "{}".format(addr)

    @staticmethod
    def print_sa_inet6(sa: bytes):
        """Render an IPv6 sockaddr as ``"<addr> scopeid <id>"``.

        Raises RtSockException if ``sa`` is shorter than sockaddr_in6.
        """
        if len(sa) < sizeof(SockaddrIn6):
            raise RtSockException("IPv6 sa size too small: {}".format(len(sa)))
        addr = socket.inet_ntop(socket.AF_INET6, sa[8:24])
        # sin6_scope_id is a host-order uint32 (see ip6_sa), so decode it with
        # native byte order rather than big-endian.
        scopeid = struct.unpack("=I", sa[24:28])[0]
        return "{} scopeid {}".format(addr, scopeid)

    @staticmethod
    def print_sa_link(sa: bytes, hd: Optional[bool] = True):
        """Render a link-level sockaddr as ``"link#<idx> ifname:<name> "``.

        ``hd`` is accepted for signature symmetry and is not used.
        Raises RtSockException on a short buffer or an out-of-range sdl_nlen.
        """
        if len(sa) < sizeof(SockaddrDl):
            raise RtSockException("LINK sa size too small: {}".format(len(sa)))
        sdl = SockaddrDl.from_buffer_copy(sa)
        if sdl.sdl_index:
            ifindex = "link#{} ".format(sdl.sdl_index)
        else:
            ifindex = ""
        if sdl.sdl_nlen:
            # The interface name starts at sdl_data, right after the 8-byte head.
            iface_offset = 8
            if sdl.sdl_nlen + iface_offset > len(sa):
                raise RtSockException(
                    "LINK sa sdl_nlen {} > total len {}".format(sdl.sdl_nlen, len(sa))
                )
            ifname = "ifname:{} ".format(
                bytes.decode(sa[iface_offset : iface_offset + sdl.sdl_nlen])
            )
        else:
            ifname = ""
        return "{}{}".format(ifindex, ifname)

    @staticmethod
    def print_sa_unknown(sa: bytes):
        """Render a sockaddr of an unhandled family by its family number."""
        return "unknown_type:{}".format(sa[1])

    @classmethod
    def print_sa(cls, sa: bytes, hd: Optional[bool] = False):
        """Render any supported sockaddr; append a repr dump when ``hd`` is set.

        Raises:
            RtSockException: ``sa`` is too short to hold sa_len and sa_family.
            Exception: the embedded sa_len disagrees with the buffer size.
        """
        # Check the length before touching sa[0]/sa[1] so short input yields a
        # clear error instead of an IndexError.
        if len(sa) < 2:
            raise RtSockException("sa too short: {} bytes".format(len(sa)))
        if sa[0] != len(sa):
            raise Exception("sa size {} != buffer size {}".format(sa[0], len(sa)))

        if sa[1] == socket.AF_INET:
            text = cls.print_sa_inet(sa)
        elif sa[1] == socket.AF_INET6:
            text = cls.print_sa_inet6(sa)
        elif sa[1] == RtConst.AF_LINK:
            text = cls.print_sa_link(sa)
        else:
            text = cls.print_sa_unknown(sa)
        if hd:
            dump = " [{!r}]".format(sa)
        else:
            dump = ""
        return "{}{}".format(text, dump)


class BaseRtsockMessage(object):
    """Common state shared by all routing socket message wrappers."""

    def __init__(self, rtm_type):
        self.rtm_type = rtm_type
        self.sa = SaHelper()

    @staticmethod
    def print_rtm_type(rtm_type):
        """Return the RTM_* name for ``rtm_type``."""
        return RtConst.get_name("RTM_", rtm_type)

    @property
    def rtm_type_str(self):
        """RTM_* name of this message's type."""
        return self.print_rtm_type(self.rtm_type)


class RtsockRtMessage(BaseRtsockMessage):
    """Route add/delete/change/get message: a header plus RTA_* sockaddrs."""

    messages = [
        RtConst.RTM_ADD,
        RtConst.RTM_DELETE,
        RtConst.RTM_CHANGE,
        RtConst.RTM_GET,
    ]

    def __init__(self, rtm_type, rtm_seq=1, dst_sa=None, mask_sa=None):
        super().__init__(rtm_type)
        self.rtm_flags = 0
        self.rtm_seq = rtm_seq
        # RTA_* bit -> raw sockaddr bytes.
        self._attrs = {}
        self.rtm_errno = 0
        self.rtm_pid = 0
        self.rtm_inits = 0
        self.rtm_rmx = RtMetrics()
        # Raw bytes when the message was decoded by from_bytes(), else None.
        self._orig_data = None
        if dst_sa:
            self.add_sa_attr(RtConst.RTA_DST, dst_sa)
        if mask_sa:
            self.add_sa_attr(RtConst.RTA_NETMASK, mask_sa)

    def add_sa_attr(self, attr_type, attr_bytes: bytes):
        """Set the raw sockaddr for RTA_* bit ``attr_type``."""
        self._attrs[attr_type] = attr_bytes

    def add_ip_attr(self, attr_type, ip_addr: str, scopeid: int = 0):
        """Set an IPv4 or IPv6 sockaddr attribute from an address string."""
        if SaHelper.is_ipv6(ip_addr):
            self.add_ip6_attr(attr_type, ip_addr, scopeid)
        else:
            self.add_ip4_attr(attr_type, ip_addr)

    def add_ip4_attr(self, attr_type, ip: str):
        """Set an IPv4 sockaddr attribute."""
        self.add_sa_attr(attr_type, self.sa.ip_sa(ip))

    def add_ip6_attr(self, attr_type, ip6: str, scopeid: int):
        """Set an IPv6 sockaddr attribute with ``scopeid``."""
        self.add_sa_attr(attr_type, self.sa.ip6_sa(ip6, scopeid))

    def add_link_attr(self, attr_type, ifindex: Optional[int] = 0):
        """Set a link-level sockaddr attribute for interface ``ifindex``."""
        self.add_sa_attr(attr_type, self.sa.link_sa(ifindex))

    def get_sa(self, attr_type) -> Optional[bytes]:
        """Return the raw sockaddr for ``attr_type``, or None if absent."""
        return self._attrs.get(attr_type)

    def print_message(self):
        """Print a one-line header summary followed by every sockaddr."""
        # RTM_GET: Report Metrics: len 272, pid: 87839, seq 1, errno 0, flags:<UP,GATEWAY,DONE,STATIC>
        if self._orig_data:
            rtm_len = len(self._orig_data)
        else:
            rtm_len = len(bytes(self))
        print(
            "{}: len {}, pid: {}, seq {}, errno {}, flags: <{}>".format(
                self.rtm_type_str,
                rtm_len,
                self.rtm_pid,
                self.rtm_seq,
                self.rtm_errno,
                RtConst.get_bitmask_str("RTF_", self.rtm_flags),
            )
        )
        rtm_addrs = sum(list(self._attrs.keys()))
        print("Addrs: <{}>".format(RtConst.get_bitmask_str("RTA_", rtm_addrs)))
        for attr in sorted(self._attrs.keys()):
            sa_data = SaHelper.print_sa(self._attrs[attr])
            print("  {}: {}".format(RtConst.get_name("RTA_", attr), sa_data))

    def print_in_message(self):
        """Print the message framed as an incoming one."""
        print("vvvvvvvv  IN vvvvvvvv")
        self.print_message()
        print()

    def verify_sa_inet(self, sa_data):
        """Assert that an IPv4 sockaddr has port 0 and an all-zero sin_zero.

        Sockaddrs shorter than sockaddr_in (but at least 8 bytes) are
        zero-padded before decoding.
        """
        if len(sa_data) < 8:
            raise Exception("IPv4 sa size too small: {}".format(sa_data))
        if sa_data[0] > len(sa_data):
            raise Exception(
                "IPv4 sin_len too big: {} vs sa size {}: {}".format(
                    sa_data[0], len(sa_data), sa_data
                )
            )
        padded = bytes(sa_data).ljust(sizeof(SockaddrIn), b"\0")
        sin = SockaddrIn.from_buffer_copy(padded)
        assert sin.sin_port == 0
        # sin.sin_zero is a c_char array, which ctypes returns as NUL-truncated
        # bytes (b"" when all zero); check the raw bytes 8..16 instead.
        assert padded[8:16] == bytes(8)

    def compare_sa(self, sa_type, sa_data):
        """Assert that ``sa_data`` equals this message's sockaddr for ``sa_type``."""
        if len(sa_data) < 4:
            sa_type_name = RtConst.get_name("RTA_", sa_type)
            raise Exception(
                "sa_len for type {} too short: {}".format(sa_type_name, len(sa_data))
            )
        our_sa = self._attrs[sa_type]
        assert SaHelper.print_sa(sa_data) == SaHelper.print_sa(our_sa)
        assert len(sa_data) == len(our_sa)
        assert sa_data == our_sa

    def verify(self, rtm_type: int, rtm_sa):
        """Assert a decoded message has ``rtm_type``, no errno and ``rtm_sa`` attrs."""
        assert self.rtm_type_str == self.print_rtm_type(rtm_type)
        assert self.rtm_errno == 0
        hdr = RtMsgHdr.from_buffer_copy(self._orig_data)
        assert hdr._rtm_spare1 == 0
        for sa_type, sa_data in rtm_sa.items():
            if sa_type not in self._attrs:
                sa_type_name = RtConst.get_name("RTA_", sa_type)
                raise Exception("SA type {} not present".format(sa_type_name))
            self.compare_sa(sa_type, sa_data)

    @classmethod
    def from_bytes(cls, data: bytes):
        """Decode a raw routing message into an instance of ``cls``.

        Sockaddrs follow the header in ascending RTA_* bit order, each padded
        to RtConst.ALIGN (a zero-length one still takes one unit).

        Raises:
            Exception: ``data`` is shorter than the header, or a sockaddr
                announced in rtm_addrs runs past the end of ``data``.
        """
        if len(data) < sizeof(RtMsgHdr):
            raise Exception(
                "messages size {} is less than expected {}".format(
                    len(data), sizeof(RtMsgHdr)
                )
            )
        hdr = RtMsgHdr.from_buffer_copy(data)

        msg = cls(hdr.rtm_type)
        msg.rtm_flags = hdr.rtm_flags
        msg.rtm_seq = hdr.rtm_seq
        msg.rtm_errno = hdr.rtm_errno
        msg.rtm_pid = hdr.rtm_pid
        msg.rtm_inits = hdr.rtm_inits
        msg.rtm_rmx = hdr.rtm_rmx
        msg._orig_data = data

        off = sizeof(RtMsgHdr)
        v = 1
        # rtm_addrs is a signed C int; treat it as 32 unsigned bits.
        addrs_mask = hdr.rtm_addrs & 0xFFFFFFFF
        while addrs_mask:
            if addrs_mask & v:
                addrs_mask -= v

                if off >= len(data):
                    raise Exception(
                        "SA for {} starts beyond message end: {} >= {}".format(
                            RtConst.get_name("RTA_", v), off, len(data)
                        )
                    )
                sa_len = data[off]
                if off + sa_len > len(data):
                    raise Exception(
                        "SA sizeof for {} > total message length: {}+{} > {}".format(
                            RtConst.get_name("RTA_", v), off, sa_len, len(data)
                        )
                    )
                msg._attrs[v] = data[off : off + sa_len]
                off += _sa_size(sa_len)
            v *= 2
        return msg

    def __bytes__(self):
        """Encode the header and sockaddrs (ascending RTA_* order) to bytes."""
        sz = sizeof(RtMsgHdr)
        addrs_mask = 0
        for attr, sa in self._attrs.items():
            sz += _sa_size(len(sa))
            addrs_mask |= attr
        hdr = RtMsgHdr(
            rtm_msglen=sz,
            rtm_version=RtConst.RTM_VERSION,
            rtm_type=self.rtm_type,
            rtm_flags=self.rtm_flags,
            rtm_seq=self.rtm_seq,
            rtm_addrs=addrs_mask,
            rtm_inits=self.rtm_inits,
            rtm_rmx=self.rtm_rmx,
        )
        buf = bytearray(sz)
        buf[0 : sizeof(RtMsgHdr)] = hdr
        off = sizeof(RtMsgHdr)
        for attr in sorted(self._attrs.keys()):
            sa = self._attrs[attr]
            buf[off : off + len(sa)] = sa
            off += _sa_size(len(sa))
        return bytes(buf)


class Rtsock:
    """Thin wrapper around a raw AF_ROUTE socket."""

    def __init__(self):
        self.socket = self._setup_rtsock()
        self.rtm_seq = 1
        self.msgmap = self.build_msgmap()

    def build_msgmap(self):
        """Return a map from RTM_* type to the class that decodes it."""
        classes = [RtsockRtMessage]
        xmap = {}
        for cls in classes:
            for message in cls.messages:
                xmap[message] = cls
        return xmap

    def get_seq(self):
        """Return the next sequence number and advance the counter."""
        ret = self.rtm_seq
        self.rtm_seq += 1
        return ret

    def get_weight(self, weight) -> int:
        """Return ``weight``, or 1 when it is falsy."""
        if weight:
            return weight
        else:
            return 1  # RT_DEFAULT_WEIGHT

    def new_rtm_any(self, msg_type, prefix: str, gw: Union[str, bytes]):
        """Build a route message for ``prefix`` ("addr" or "addr/len").

        ``gw`` is either an address string or a raw gateway sockaddr.
        """
        px = prefix.split("/")
        addr_sa = SaHelper.ip_sa(px[0])
        if len(px) > 1:
            pxlen = int(px[1])
            if SaHelper.is_ipv6(px[0]):
                mask_sa = SaHelper.pxlen6_sa(pxlen)
            else:
                mask_sa = SaHelper.pxlen4_sa(pxlen)
        else:
            mask_sa = None
        msg = RtsockRtMessage(msg_type, self.get_seq(), addr_sa, mask_sa)
        if isinstance(gw, bytes):
            msg.add_sa_attr(RtConst.RTA_GATEWAY, gw)
        else:
            # String
            msg.add_ip_attr(RtConst.RTA_GATEWAY, gw)
        return msg

    def new_rtm_add(self, prefix: str, gw: Union[str, bytes]):
        """Build an RTM_ADD message."""
        return self.new_rtm_any(RtConst.RTM_ADD, prefix, gw)

    def new_rtm_del(self, prefix: str, gw: Union[str, bytes]):
        """Build an RTM_DELETE message."""
        return self.new_rtm_any(RtConst.RTM_DELETE, prefix, gw)

    def new_rtm_change(self, prefix: str, gw: Union[str, bytes]):
        """Build an RTM_CHANGE message."""
        return self.new_rtm_any(RtConst.RTM_CHANGE, prefix, gw)

    def _setup_rtsock(self) -> socket.socket:
        """Open the raw routing socket with SO_USELOOPBACK enabled."""
        s = socket.socket(socket.AF_ROUTE, socket.SOCK_RAW, socket.AF_UNSPEC)
        s.setsockopt(socket.SOL_SOCKET, socket.SO_USELOOPBACK, 1)
        return s

    def print_hd(self, data: bytes):
        """Print ``data`` as a hex dump, 16 bytes per row."""
        width = 16
        print("==========================================")
        for chunk in [data[i : i + width] for i in range(0, len(data), width)]:
            for b in chunk:
                print("0x{:02X} ".format(b), end="")
            print()
        print()

    def write_message(self, msg):
        """Print ``msg`` and write its encoded form to the routing socket."""
        print("vvvvvvvv  OUT vvvvvvvv")
        msg.print_message()
        print()
        msg_bytes = bytes(msg)
        # os.write() raises OSError on failure; it never returns -1.
        ret = os.write(self.socket.fileno(), msg_bytes)
        assert ret == len(msg_bytes)

    def parse_message(self, data: bytes):
        """Decode ``data`` with the class registered for its rtm_type.

        Returns None for message types not present in ``self.msgmap``.
        Raises OSError if ``data`` is too short to contain rtm_type.
        """
        # rtm_type follows rtm_msglen (2 bytes) and rtm_version (1 byte).
        if len(data) < 4:
            raise OSError("Short read from rtsock: {} bytes".format(len(data)))
        rtm_type = data[3]
        msg_cls = self.msgmap.get(rtm_type)
        if msg_cls is None:
            return None
        return msg_cls.from_bytes(data)

    def write_data(self, data: bytes):
        """Send raw ``data`` on the routing socket."""
        self.socket.send(data)

    def read_data(self, seq: Optional[int] = None) -> bytes:
        """Read one message; if ``seq`` is given, skip until rtm_seq matches."""
        while True:
            data = self.socket.recv(4096)
            if seq is None:
                break
            # A header-only reply is a complete message too, hence >=.
            if len(data) >= sizeof(RtMsgHdr):
                hdr = RtMsgHdr.from_buffer_copy(data)
                if hdr.rtm_seq == seq:
                    break
        return data

    def read_message(self) -> Optional[BaseRtsockMessage]:
        """Read one message and decode it (None for unhandled types)."""
        data = self.read_data()
        return self.parse_message(data)