#! /usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0

import argparse
import ctypes
import errno
import hashlib
import os
import select
import signal
import socket
import subprocess
import sys
import atexit
from pwd import getpwuid
from os import stat

# Allow utils module to be imported from different directory
this_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(this_dir, "../"))
from lib.py.utils import ip

# use_errno=True lets us report why setns() failed via ctypes.get_errno();
# a ctypes foreign call never raises on its own, it just returns -1.
libc = ctypes.CDLL('libc.so.6', use_errno=True)
setns = libc.setns

net0 = 'net0'
net1 = 'net1'

veth0 = 'veth0'
veth1 = 'veth1'


# Helper function for creating a socket inside a network namespace.
# We need this because otherwise RDS will detect that the two TCP
# sockets are on the same interface and use the loop transport instead
# of the TCP transport.
def netns_socket(netns, *args):
    """Create a socket inside the network namespace named *netns*.

    A child process is forked, switched into the namespace with setns(),
    and creates the socket there; the descriptor is then passed back to
    the parent over a UNIX socket pair.

    :param netns: name of the namespace under /var/run/netns
    :param args: positional arguments for socket.socket() (family, type)
    :return: a socket object living in the requested namespace
    :raises OSError: if the child could not create/hand over the socket
    """
    u0, u1 = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)

    child = os.fork()
    if child == 0:
        status = 1
        try:
            u1.close()

            # change network namespace
            with open(f'/var/run/netns/{netns}') as f:
                if setns(f.fileno(), 0) != 0:
                    err = ctypes.get_errno()
                    raise OSError(err, os.strerror(err), f.name)

            # create socket in target namespace
            s = socket.socket(*args)

            # send resulting socket to parent
            socket.send_fds(u0, [], [s.fileno()])
            status = 0
        except Exception as e:
            print(f'netns_socket({netns}): {e}', file=sys.stderr)
        finally:
            # os._exit() skips atexit handlers and does not flush stdio
            # buffers inherited from the parent (which would duplicate
            # the parent's pending output).
            sys.stderr.flush()
            os._exit(status)

    # Drop our copy of the sending end so that a dead child results in
    # an empty read below instead of blocking forever.
    u0.close()

    # receive socket from child
    try:
        _, fds, _, _ = socket.recv_fds(u1, 0, 1)
    finally:
        u1.close()
        os.waitpid(child, 0)

    if not fds:
        raise OSError(f'failed to create socket in netns {netns}')

    # socket.fromfd() duplicates the descriptor, so release the one we
    # received once the socket object owns its own copy.
    try:
        return socket.fromfd(fds[0], *args)
    finally:
        os.close(fds[0])


def signal_handler(sig, frame):
    """SIGALRM handler: report the timeout and exit with failure."""
    print('Test timed out')
    sys.exit(1)


# Parse out command line arguments. We take an optional
# timeout parameter and an optional log output folder
parser = argparse.ArgumentParser(description="init script args",
                  formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-d", "--logdir", action="store",
                    help="directory to store logs", default="/tmp")
parser.add_argument('--timeout', help="timeout to terminate hung test",
                    type=int, default=0)
parser.add_argument('-l', '--loss', help="Simulate tcp packet loss",
                    type=int, default=0)
parser.add_argument('-c', '--corruption', help="Simulate tcp packet corruption",
                    type=int, default=0)
parser.add_argument('-u', '--duplicate', help="Simulate tcp packet duplication",
                    type=int, default=0)
args = parser.parse_args()
logdir = args.logdir
packet_loss = str(args.loss) + '%'
packet_corruption = str(args.corruption) + '%'
packet_duplicate = str(args.duplicate) + '%'

ip(f"netns add {net0}")
ip(f"netns add {net1}")
# Name both ends explicitly rather than relying on the kernel's
# automatic vethN numbering matching the names used below.
ip(f"link add {veth0} type veth peer name {veth1}")

addrs = [
    # we technically don't need different port numbers, but this will
    # help identify traffic in the network analyzer
    ('10.0.0.1', 10000),
    ('10.0.0.2', 20000),
]

# move interfaces to separate namespaces so they can no longer be
# bound directly; this prevents rds from switching over from the tcp
# transport to the loop transport.
ip(f"link set {veth0} netns {net0} up")
ip(f"link set {veth1} netns {net1} up")

# add addresses
ip(f"-n {net0} addr add {addrs[0][0]}/32 dev {veth0}")
ip(f"-n {net1} addr add {addrs[1][0]}/32 dev {veth1}")

# add routes
ip(f"-n {net0} route add {addrs[1][0]}/32 dev {veth0}")
ip(f"-n {net1} route add {addrs[0][0]}/32 dev {veth1}")

# sanity check that our two interfaces/addresses are correctly set up
# and communicating by doing a single ping
ip(f"netns exec {net0} ping -c 1 {addrs[1][0]}")

# Start a packet capture on each network. Each capture runs in a forked
# child that blocks in tcpdump until it is killed further down.
for net in [net0, net1]:
    tcpdump_pid = os.fork()
    if tcpdump_pid == 0:
        pcap = logdir+'/'+net+'.pcap'
        subprocess.check_call(['touch', pcap])
        # have tcpdump drop privileges to the owner of the capture file
        user = getpwuid(stat(pcap).st_uid).pw_name
        ip(f"netns exec {net} /usr/sbin/tcpdump -Z {user} -i any -w {pcap}")
        sys.exit(0)

# simulate packet loss, duplication and corruption
for net, iface in [(net0, veth0), (net1, veth1)]:
    # Adjacent literals instead of backslash continuations, so the
    # source indentation does not end up inside the command string.
    ip(f"netns exec {net} /usr/sbin/tc qdisc add dev {iface} root netem "
       f"corrupt {packet_corruption} loss {packet_loss} "
       f"duplicate {packet_duplicate}")

# add a timeout
if args.timeout > 0:
    # Install the handler before arming the alarm so SIGALRM can never
    # arrive while the default (terminate) action is still in place.
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(args.timeout)

sockets = [
    netns_socket(net0, socket.AF_RDS, socket.SOCK_SEQPACKET),
    netns_socket(net1, socket.AF_RDS, socket.SOCK_SEQPACKET),
]

for s, addr in zip(sockets, addrs):
    s.bind(addr)
    s.setblocking(False)

# lookup tables between descriptors, sockets and bound addresses
fileno_to_socket = {
    s.fileno(): s for s in sockets
}

addr_to_socket = {
    addr: s for addr, s in zip(addrs, sockets)
}

socket_to_addr = {
    s: addr for addr, s in zip(addrs, sockets)
}

# running sha256 per (sender fd, receiver fd) pair, for both directions
send_hashes = {}
recv_hashes = {}

ep = select.epoll()

for s in sockets:
    ep.register(s, select.EPOLLRDNORM)

n = 50000
nr_send = 0
nr_recv = 0

while nr_send < n:
    # Send as much as we can without blocking
    print("sending...", nr_send, nr_recv)
    while nr_send < n:
        send_data = hashlib.sha256(
            f'packet {nr_send}'.encode('utf-8')).hexdigest().encode('utf-8')

        # pseudo-random send/receive pattern
        sender = sockets[nr_send % 2]
        receiver = sockets[1 - (nr_send % 3) % 2]

        try:
            sender.sendto(send_data, socket_to_addr[receiver])
            send_hashes.setdefault((sender.fileno(), receiver.fileno()),
                    hashlib.sha256()).update(f'<{send_data}>'.encode('utf-8'))
            nr_send += 1
        except BlockingIOError:
            break
        except OSError as e:
            # transient conditions: stop sending and drain the receivers
            if e.errno in [errno.ENOBUFS, errno.ECONNRESET, errno.EPIPE]:
                break
            raise

    # Receive as much as we can without blocking
    print("receiving...", nr_send, nr_recv)
    while nr_recv < nr_send:
        for fileno, eventmask in ep.poll():
            receiver = fileno_to_socket[fileno]

            if eventmask & select.EPOLLRDNORM:
                # drain this socket until it would block
                while True:
                    try:
                        recv_data, address = receiver.recvfrom(1024)
                        sender = addr_to_socket[address]
                        recv_hashes.setdefault((sender.fileno(),
                            receiver.fileno()), hashlib.sha256()).update(
                                f'<{recv_data}>'.encode('utf-8'))
                        nr_recv += 1
                    except BlockingIOError:
                        break

    # exercise net/rds/tcp.c:rds_tcp_sysctl_reset()
    for net in [net0, net1]:
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_rcvbuf=10000")
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_sndbuf=10000")

print("done", nr_send, nr_recv)

# the Python socket module doesn't know these
RDS_INFO_FIRST = 10000
RDS_INFO_LAST = 10017

nr_success = 0
nr_error = 0

for s in sockets:
    for optname in range(RDS_INFO_FIRST, RDS_INFO_LAST + 1):
        # Sigh, the Python socket module doesn't allow us to pass
        # buffer lengths greater than 1024 for some reason. RDS
        # wants multiple pages.
        try:
            s.getsockopt(socket.SOL_RDS, optname, 1024)
            nr_success += 1
        except OSError:
            # Failures (including ENOSPC) are only counted, never fatal:
            # the point is to exercise the info handlers.
            nr_error += 1

print(f"getsockopt(): {nr_success}/{nr_error}")

print("Stopping network packet captures")
subprocess.check_call(['killall', '-q', 'tcpdump'])

# We're done sending and receiving stuff, now let's check if what
# we received is what we sent.
for (sender, receiver), send_hash in send_hashes.items():
    recv_hash = recv_hashes.get((sender, receiver))

    if recv_hash is None:
        print("FAIL: No data received")
        sys.exit(1)

    if send_hash.hexdigest() != recv_hash.hexdigest():
        print("FAIL: Send/recv mismatch")
        print("hash expected:", send_hash.hexdigest())
        print("hash received:", recv_hash.hexdigest())
        sys.exit(1)

    print(f"{sender}/{receiver}: ok")

print("Success")
sys.exit(0)