#! /usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0

import argparse
import ctypes
import errno
import hashlib
import os
import select
import signal
import socket
import subprocess
import sys
import atexit
from pwd import getpwuid
from os import stat
from lib.py import ip


# use_errno=True makes ctypes save errno after every foreign call, so a
# failing setns() can be reported with ctypes.get_errno() (ctypes never
# raises on a C-level -1 return).
libc = ctypes.CDLL('libc.so.6', use_errno=True)
setns = libc.setns

net0 = 'net0'
net1 = 'net1'

veth0 = 'veth0'
veth1 = 'veth1'

# Helper function for creating a socket inside a network namespace.
# We need this because otherwise RDS will detect that the two TCP
# sockets are on the same interface and use the loop transport instead
# of the TCP transport.
def netns_socket(netns, *args):
    """Create a socket inside network namespace *netns* and return it.

    A forked child joins *netns* with setns(2), creates
    ``socket.socket(*args)`` there and hands the descriptor back over an
    AF_UNIX socketpair. The calling process never changes namespace.

    :param netns: namespace name; its handle is opened from /var/run/netns/.
    :param args: positional arguments for socket.socket() (family, type, ...).
    :returns: a socket.socket wrapping the descriptor created in *netns*.
    :raises RuntimeError: if the child could not join *netns* or create the
        socket (the child's own error is printed to stderr).
    """
    u0, u1 = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)

    child = os.fork()
    if child == 0:
        # Child: always leave through os._exit() so the parent's atexit
        # handlers and inherited stdio buffers are not run/flushed again.
        status = 1
        try:
            # change network namespace
            with open(f'/var/run/netns/{netns}') as f:
                if setns(f.fileno(), 0) != 0:
                    err = ctypes.get_errno()
                    raise OSError(err, os.strerror(err), f.name)

            # create socket in target namespace
            s = socket.socket(*args)

            # send resulting socket to parent
            socket.send_fds(u0, [], [s.fileno()])
            status = 0
        except Exception as e:
            print(f'netns_socket({netns}): {e}', file=sys.stderr)
            sys.stderr.flush()
        finally:
            os._exit(status)

    # Parent: drop our copy of the child's end so that, if the child dies
    # without sending anything, recv_fds() sees EOF instead of blocking.
    u0.close()
    try:
        # receive socket from child
        _, fds, _, _ = socket.recv_fds(u1, 0, 1)
    finally:
        u1.close()
        _, status = os.waitpid(child, 0)

    if not fds:
        raise RuntimeError(f'could not create socket in netns {netns} '
                           f'(child wait status {status})')

    # fromfd() duplicates the descriptor; close the received original so
    # it is not leaked.
    try:
        return socket.fromfd(fds[0], *args)
    finally:
        os.close(fds[0])

def signal_handler(sig, frame):
    """Report a timed-out run and exit with status 1.

    Installed for SIGALRM when --timeout is given.
    """
    print('Test timed out')
    sys.exit(1)
# Parse out command line arguments. Everything is optional: a timeout
# that aborts a hung run, a folder for log output (packet captures), and
# netem impairment percentages for loss, corruption and duplication.
parser = argparse.ArgumentParser(
    description="init script args",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("-d", "--logdir", action="store", default="/tmp",
                    help="directory to store logs")
parser.add_argument("--timeout", type=int, default=0,
                    help="timeout to terminate hung test")
for _short, _long, _help in (
    ("-l", "--loss", "Simulate tcp packet loss"),
    ("-c", "--corruption", "Simulate tcp packet corruption"),
    ("-u", "--duplicate", "Simulate tcp packet duplication"),
):
    parser.add_argument(_short, _long, type=int, default=0, help=_help)
args = parser.parse_args()
logdir = args.logdir
# netem takes its rates as percentages, e.g. "5%"
packet_loss, packet_corruption, packet_duplicate = (
    f"{pct}%" for pct in (args.loss, args.corruption, args.duplicate)
)

# two namespaces plus one veth pair (veth0/veth1) to connect them
ip(f"netns add {net0}")
ip(f"netns add {net1}")
ip("link add type veth")

# (address, port) endpoint for each side, in namespace order. Distinct
# ports are not technically needed, but they make each side easy to tell
# apart in the network analyzer.
addrs = [
    ('10.0.0.1', 10000),
    ('10.0.0.2', 20000),
]

# move interfaces to separate namespaces so they can no longer be
# bound directly; this prevents rds from switching over from the tcp
# transport to the loop transport.
ip(f"link set {veth0} netns {net0} up")
ip(f"link set {veth1} netns {net1} up")

# add addresses: a single /32 on each veth end
ip(f"-n {net0} addr add {addrs[0][0]}/32 dev {veth0}")
ip(f"-n {net1} addr add {addrs[1][0]}/32 dev {veth1}")

# add routes: a host route to the peer's /32 through the local veth end
ip(f"-n {net0} route add {addrs[1][0]}/32 dev {veth0}")
ip(f"-n {net1} route add {addrs[0][0]}/32 dev {veth1}")

# sanity check that our two interfaces/addresses are correctly set up
# and communicating by doing a single ping
ip(f"netns exec {net0} ping -c 1 {addrs[1][0]}")

# Start a packet capture on each network. Each capture runs in a forked
# child that never continues into the code below.
for net in [net0, net1]:
    tcpdump_pid = os.fork()
    if tcpdump_pid == 0:
        pcap = logdir+'/'+net+'.pcap'
        # create the file first so its owner is known, then let tcpdump
        # run as that user (-Z) while writing the capture
        subprocess.check_call(['touch', pcap])
        user = getpwuid(stat(pcap).st_uid).pw_name
        ip(f"netns exec {net} /usr/sbin/tcpdump -Z {user} -i any -w {pcap}")
        sys.exit(0)

# simulate packet loss, duplication and corruption on both veth ends
for net, iface in [(net0, veth0), (net1, veth1)]:
    ip(f"netns exec {net} /usr/sbin/tc qdisc add dev {iface} root netem "
       f"corrupt {packet_corruption} loss {packet_loss} "
       f"duplicate {packet_duplicate}")

# add a timeout. The handler is installed before the alarm is armed:
# SIGALRM's default action terminates the process without our message.
if args.timeout > 0:
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(args.timeout)

# one RDS socket per namespace, bound to that side's address, non-blocking
sockets = [
    netns_socket(net0, socket.AF_RDS, socket.SOCK_SEQPACKET),
    netns_socket(net1, socket.AF_RDS, socket.SOCK_SEQPACKET),
]

for s, addr in zip(sockets, addrs):
    s.bind(addr)
    s.setblocking(0)

# lookup tables between epoll file descriptors, sockets and addresses
fileno_to_socket = {
    s.fileno(): s for s in sockets
}

addr_to_socket = {
    addr: s for addr, s in zip(addrs, sockets)
}

socket_to_addr = {
    s: addr for addr, s in zip(addrs, sockets)
}

# running SHA-256 over everything sent / received, keyed by
# (sender fd, receiver fd); the two sides are compared at the end
send_hashes = {}
recv_hashes = {}

ep = select.epoll()

for s in sockets:
    ep.register(s, select.EPOLLRDNORM)

n = 50000
nr_send = 0
nr_recv = 0

while nr_send < n:
    # Send as much as we can without blocking
    print("sending...", nr_send, nr_recv)
    while nr_send < n:
        send_data = hashlib.sha256(
            f'packet {nr_send}'.encode('utf-8')).hexdigest().encode('utf-8')

        # pseudo-random send/receive pattern; it repeats every 6 messages
        # and covers all four (sender, receiver) pairs, including a socket
        # sending to itself
        sender = sockets[nr_send % 2]
        receiver = sockets[1 - (nr_send % 3) % 2]

        try:
            sender.sendto(send_data, socket_to_addr[receiver])
            send_hashes.setdefault((sender.fileno(), receiver.fileno()),
                    hashlib.sha256()).update(f'<{send_data}>'.encode('utf-8'))
            nr_send = nr_send + 1
        except BlockingIOError:
            break
        except OSError as e:
            # Back-pressure or a connection reset: stop sending for now and
            # drain the receivers. nr_send was not advanced, so the same
            # message is retried on the next pass of the outer loop.
            if e.errno in [errno.ENOBUFS, errno.ECONNRESET, errno.EPIPE]:
                break
            raise

    # Receive as much as we can without blocking. ep.poll() itself has no
    # timeout; --timeout is what guards against a hang here.
    print("receiving...", nr_send, nr_recv)
    while nr_recv < nr_send:
        for fileno, eventmask in ep.poll():
            receiver = fileno_to_socket[fileno]

            if eventmask & select.EPOLLRDNORM:
                # drain this socket until it would block
                while True:
                    try:
                        recv_data, address = receiver.recvfrom(1024)
                        sender = addr_to_socket[address]
                        recv_hashes.setdefault((sender.fileno(),
                                receiver.fileno()), hashlib.sha256()).update(
                                        f'<{recv_data}>'.encode('utf-8'))
                        nr_recv = nr_recv + 1
                    except BlockingIOError:
                        break

    # exercise net/rds/tcp.c:rds_tcp_sysctl_reset()
    for net in [net0, net1]:
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_rcvbuf=10000")
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_sndbuf=10000")

print("done", nr_send, nr_recv)

# the Python socket module doesn't know these
RDS_INFO_FIRST = 10000
RDS_INFO_LAST = 10017

nr_success = 0
nr_error = 0

for s in sockets:
    for optname in range(RDS_INFO_FIRST, RDS_INFO_LAST + 1):
        # Sigh, the Python socket module doesn't allow us to pass
        # buffer lengths greater than 1024 for some reason. RDS
        # wants multiple pages.
        try:
            s.getsockopt(socket.SOL_RDS, optname, 1024)
            nr_success = nr_success + 1
        except OSError:
            # Failures are tolerated and only tallied: the 1024-byte buffer
            # is often too small (ENOSPC), and every other error is
            # ignored too; only the counts are reported.
            nr_error = nr_error + 1

print(f"getsockopt(): {nr_success}/{nr_error}")

print("Stopping network packet captures")
subprocess.check_call(['killall', '-q', 'tcpdump'])

# We're done sending and receiving stuff, now let's check if what
# we received is what we sent.
for (sender, receiver), send_hash in send_hashes.items():
    recv_hash = recv_hashes.get((sender, receiver))

    if recv_hash is None:
        print("FAIL: No data received")
        sys.exit(1)

    if send_hash.hexdigest() != recv_hash.hexdigest():
        print("FAIL: Send/recv mismatch")
        print("hash expected:", send_hash.hexdigest())
        print("hash received:", recv_hash.hexdigest())
        sys.exit(1)

    print(f"{sender}/{receiver}: ok")

print("Success")
sys.exit(0)