#! /usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0
"""
This module provides functional testing for the net/rds component.
"""

import argparse
import ctypes
import errno
import hashlib
import os
import select
import signal
import socket
import subprocess
import sys

# Allow utils module to be imported from different directory
this_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(this_dir, "../"))
# pylint: disable-next=wrong-import-position,import-error,no-name-in-module
from lib.py.utils import ip  # noqa: E402
# pylint: disable-next=wrong-import-position,import-error,no-name-in-module
from lib.py.ksft import ksft_pr  # noqa: E402

# use_errno=True lets a failing setns() be diagnosed via ctypes.get_errno().
# Calls made through ctypes never raise on a -1 return; the caller must
# check the return value explicitly.
libc = ctypes.CDLL('libc.so.6', use_errno=True)
setns = libc.setns

NET0 = 'net0'
NET1 = 'net1'

VETH0 = 'veth0'
VETH1 = 'veth1'


# Helper function for creating a socket inside a network namespace.
# We need this because otherwise RDS will detect that the two TCP
# sockets are on the same interface and use the loop transport instead
# of the TCP transport.
def netns_socket(netns, *sock_args):
    """
    Creates sockets inside of network namespace

    A child process is forked; it switches to the target namespace with
    setns(2), creates the socket there and passes the descriptor back to
    the parent over an AF_UNIX socket pair.

    :param netns: the name of the network namespace
    :param sock_args: socket family and type
    :returns: a socket object created inside @netns
    :raises OSError: if the child failed to create or hand over the socket
    """
    u0, u1 = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)

    child = os.fork()
    if child == 0:
        # The child must never return into the caller's code path, or it
        # would go on to run the rest of the test as a second process.
        status = 1
        try:
            u1.close()
            # change network namespace
            with open(f'/var/run/netns/{netns}', encoding='utf-8') as f:
                if setns(f.fileno(), 0) != 0:
                    err = ctypes.get_errno()
                    raise OSError(err, f"setns({netns}): {os.strerror(err)}")

            # create socket in target namespace
            sock = socket.socket(*sock_args)

            # send resulting socket to parent
            socket.send_fds(u0, [], [sock.fileno()])
            status = 0
        except OSError as e:
            # stdout carries the TAP stream, so report on stderr only
            print(f"netns_socket({netns}): {e}", file=sys.stderr)
        finally:
            os._exit(status)

    # Drop the parent's copy of the child's end so that a child dying
    # before sending results in EOF instead of a recv that never returns.
    u0.close()
    try:
        # receive socket from child
        _, fds, _, _ = socket.recv_fds(u1, 0, 1)
    finally:
        u1.close()
        os.waitpid(child, 0)

    if not fds:
        raise OSError(errno.EIO, f"failed to create socket in netns {netns}")

    # socket.fromfd() dup()s the descriptor; close the received one so it
    # does not leak.
    try:
        return socket.fromfd(fds[0], *sock_args)
    finally:
        for fd in fds:
            os.close(fd)


# Running tcpdump processes. Defined before the signal handler can ever
# fire so stop_pcaps() always finds it.
tcpdump_procs = []


def stop_pcaps():
    """Stop tcpdump processes.

    We use pop() here to drain the list in the event that the test
    completes after the signal handler is fired. List will be empty
    if logdir is not set
    """

    if not tcpdump_procs:
        return

    ksft_pr("Stopping network packet captures")
    while tcpdump_procs:
        proc = tcpdump_procs.pop()
        proc.terminate()
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()


def signal_handler(_sig, _frame):
    """
    Test timed out signal handler
    """
    ksft_pr("Test timed out")
    stop_pcaps()
    print("not ok 1 rds selftest")
    sys.exit(1)


# Parse out command line arguments. We take an optional
# timeout parameter and an optional log output folder
parser = argparse.ArgumentParser(description="init script args",
                  formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-d", "--logdir", action="store",
                    help="directory to store logs", default=None)
parser.add_argument('-t', '--timeout', help="timeout to terminate hung test",
                    type=int, default=0)
parser.add_argument('-l', '--loss', help="Simulate tcp packet loss",
                    type=int, default=0)
parser.add_argument('-c', '--corruption', help="Simulate tcp packet corruption",
                    type=int, default=0)
parser.add_argument('-u', '--duplicate', help="Simulate tcp packet duplication",
                    type=int, default=0)
args = parser.parse_args()
logdir = args.logdir
PACKET_LOSS = str(args.loss) + '%'
PACKET_CORRUPTION = str(args.corruption) + '%'
PACKET_DUPLICATE = str(args.duplicate) + '%'

ip(f"netns add {NET0}")
ip(f"netns add {NET1}")
ip("link add type veth")

addrs = [
    # we technically don't need different port numbers, but this will
    # help identify traffic in the network analyzer
    ('10.0.0.1', 10000),
    ('10.0.0.2', 20000),
]

# move interfaces to separate namespaces so they can no longer be
# bound directly; this prevents rds from switching over from the tcp
# transport to the loop transport.
ip(f"link set {VETH0} netns {NET0} up")
ip(f"link set {VETH1} netns {NET1} up")

# add addresses
ip(f"-n {NET0} addr add {addrs[0][0]}/32 dev {VETH0}")
ip(f"-n {NET1} addr add {addrs[1][0]}/32 dev {VETH1}")

# add routes
ip(f"-n {NET0} route add {addrs[1][0]}/32 dev {VETH0}")
ip(f"-n {NET1} route add {addrs[0][0]}/32 dev {VETH1}")

# sanity check that our two interfaces/addresses are correctly set up
# and communicating by doing a single ping
ip(f"netns exec {NET0} ping -c 1 {addrs[1][0]}")

# Start a packet capture on each network
if logdir is not None:
    for net in [NET0, NET1]:
        pcap = os.path.join(logdir, f'rds-{net}.pcap')

        tcpdump_cmd = ['ip', 'netns', 'exec', net, '/usr/sbin/tcpdump']
        sudo_user = os.environ.get('SUDO_USER')
        if sudo_user:
            # drop privileges so the capture file is owned by the caller
            tcpdump_cmd.extend(['-Z', sudo_user])
        tcpdump_cmd.extend(['-i', 'any', '-w', pcap])

        # pylint: disable-next=consider-using-with
        p = subprocess.Popen(tcpdump_cmd,
                             stdout=subprocess.DEVNULL,
                             stderr=subprocess.DEVNULL)
        tcpdump_procs.append(p)

# simulate packet loss, duplication and corruption
for net, iface in [(NET0, VETH0), (NET1, VETH1)]:
    ip(f"netns exec {net} /usr/sbin/tc qdisc add dev {iface} root netem "
       f"corrupt {PACKET_CORRUPTION} loss {PACKET_LOSS} "
       f"duplicate {PACKET_DUPLICATE}")

print("TAP version 13")
print("1..1")

# add a timeout; install the handler before arming the alarm so a very
# short timeout cannot fire with the default (terminating) disposition
if args.timeout > 0:
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(args.timeout)

sockets = [
    netns_socket(NET0, socket.AF_RDS, socket.SOCK_SEQPACKET),
    netns_socket(NET1, socket.AF_RDS, socket.SOCK_SEQPACKET),
]

for s, addr in zip(sockets, addrs):
    s.bind(addr)
    s.setblocking(0)

fileno_to_socket = {
    s.fileno(): s for s in sockets
}

addr_to_socket = dict(zip(addrs, sockets))

socket_to_addr = {
    s: addr for addr, s in zip(addrs, sockets)
}

# Running SHA-256 digests of every message, keyed by
# (sender fileno, receiver fileno), for end-to-end comparison.
send_hashes = {}
recv_hashes = {}

ep = select.epoll()

for s in sockets:
    ep.register(s, select.EPOLLRDNORM)

NUM_PACKETS = 50000
nr_send = 0
nr_recv = 0

while nr_send < NUM_PACKETS:
    # Send as much as we can without blocking
    ksft_pr("sending...", nr_send, nr_recv)
    while nr_send < NUM_PACKETS:
        send_data = hashlib.sha256(
            f'packet {nr_send}'.encode('utf-8')).hexdigest().encode('utf-8')

        # pseudo-random send/receive pattern
        sender = sockets[nr_send % 2]
        receiver = sockets[1 - (nr_send % 3) % 2]

        try:
            sender.sendto(send_data, socket_to_addr[receiver])
            send_hashes.setdefault((sender.fileno(), receiver.fileno()),
                    hashlib.sha256()).update(f'<{send_data}>'.encode('utf-8'))
            nr_send = nr_send + 1
        except BlockingIOError:
            break
        except OSError as e:
            # transient back-pressure / connection churn: go drain receivers
            if e.errno in [errno.ENOBUFS, errno.ECONNRESET, errno.EPIPE]:
                break
            raise

    # Receive as much as we can without blocking
    ksft_pr("receiving...", nr_send, nr_recv)
    while nr_recv < nr_send:
        for fileno, eventmask in ep.poll():
            receiver = fileno_to_socket[fileno]

            if eventmask & select.EPOLLRDNORM:
                while True:
                    try:
                        recv_data, address = receiver.recvfrom(1024)
                        sender = addr_to_socket[address]
                        recv_hashes.setdefault((sender.fileno(),
                            receiver.fileno()), hashlib.sha256()).update(
                            f'<{recv_data}>'.encode('utf-8'))
                        nr_recv = nr_recv + 1
                    except BlockingIOError:
                        break

    # exercise net/rds/tcp.c:rds_tcp_sysctl_reset()
    for net in [NET0, NET1]:
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_rcvbuf=10000")
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_sndbuf=10000")

ksft_pr("done", nr_send, nr_recv)

# the Python socket module doesn't know these
RDS_INFO_FIRST = 10000
RDS_INFO_LAST = 10017

nr_success = 0
nr_error = 0

for s in sockets:
    for optname in range(RDS_INFO_FIRST, RDS_INFO_LAST + 1):
        # Sigh, the Python socket module doesn't allow us to pass
        # buffer lengths greater than 1024 for some reason. RDS
        # wants multiple pages.
        try:
            s.getsockopt(socket.SOL_RDS, optname, 1024)
            nr_success = nr_success + 1
        except OSError:
            # Errors (e.g. ENOSPC because 1024 bytes is too small) are
            # only counted; this loop exercises the info code paths.
            nr_error = nr_error + 1

ksft_pr(f"getsockopt(): {nr_success}/{nr_error}")

# cancel timeout
signal.alarm(0)

stop_pcaps()

# We're done sending and receiving stuff, now let's check if what
# we received is what we sent.
ret = 0
for (sender, receiver), send_hash in send_hashes.items():
    recv_hash = recv_hashes.get((sender, receiver))

    if recv_hash is None:
        ksft_pr("FAIL: No data received")
        ret = 1
        break

    if send_hash.hexdigest() != recv_hash.hexdigest():
        ksft_pr("FAIL: Send/recv mismatch")
        ksft_pr("hash expected:", send_hash.hexdigest())
        ksft_pr("hash received:", recv_hash.hexdigest())
        ret = 1
        break

    ksft_pr(f"{sender}/{receiver}: ok")

if ret == 0:
    ksft_pr("Success")
    print("ok 1 rds selftest")
else:
    print("not ok 1 rds selftest")

ksft_pr(f"Totals: pass:{1-ret} fail:{ret} skip:0")
sys.exit(ret)