#! /usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0

import argparse
import ctypes
import errno
import hashlib
import os
import select
import signal
import socket
import subprocess
import sys
import atexit
from pwd import getpwuid
from os import stat

from lib.py import ip

# use_errno=True makes ctypes keep a private copy of errno that can be
# read with ctypes.get_errno(); ctypes never raises on a -1 return value
# by itself, so the return code has to be checked explicitly.
libc = ctypes.CDLL('libc.so.6', use_errno=True)
setns = libc.setns

net0 = 'net0'
net1 = 'net1'

veth0 = 'veth0'
veth1 = 'veth1'


# Helper function for creating a socket inside a network namespace.
# We need this because otherwise RDS will detect that the two TCP
# sockets are on the same interface and use the loop transport instead
# of the TCP transport.
def netns_socket(netns, *args):
    """Create a socket inside the network namespace named *netns*.

    A child process is forked, switched into the namespace with setns(),
    and creates the socket there; the descriptor is then passed back to
    the parent over a UNIX socket pair.

    Args:
        netns: name of the namespace, looked up in /var/run/netns/.
        *args: positional arguments forwarded to socket.socket() in the
            child and to socket.fromfd() in the parent (family, type, ...).

    Returns:
        A socket object living in the requested namespace.

    Raises:
        RuntimeError: the child exited without handing over a socket
            (for instance because setns() or socket creation failed).
    """
    u0, u1 = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)

    child = os.fork()
    if child == 0:
        # the child only ever sends
        u1.close()

        # change network namespace
        with open(f'/var/run/netns/{netns}') as f:
            if setns(f.fileno(), 0) != 0:
                err = ctypes.get_errno()
                print(f'setns({netns}) failed: {os.strerror(err)}',
                      file=sys.stderr)
                # Do not continue: the socket would silently be created
                # in the caller's namespace instead of the target one.
                sys.exit(1)

        # create socket in target namespace
        s = socket.socket(*args)

        # send resulting socket to parent
        socket.send_fds(u0, [], [s.fileno()])

        sys.exit(0)

    # Drop our copy of the sending end first, so that the receive below
    # sees end-of-file instead of blocking if the child dies early.
    u0.close()

    # receive socket from child
    try:
        _, fds, _, _ = socket.recv_fds(u1, 0, 1)
    finally:
        u1.close()
        os.waitpid(child, 0)

    if not fds:
        raise RuntimeError(f'could not create socket in netns {netns}')

    # socket.fromfd() duplicates the descriptor, so the one we received
    # must be closed here or it is leaked.
    try:
        return socket.fromfd(fds[0], *args)
    finally:
        os.close(fds[0])


def signal_handler(sig, frame):
    """SIGALRM handler: report the timeout and exit with status 1."""
    print('Test timed out')
    sys.exit(1)


# Parse out command line arguments.
# We take an optional timeout parameter and an optional log output
# folder, plus knobs for the netem impairments applied to the link.
parser = argparse.ArgumentParser(
    description="init script args",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "-d", "--logdir",
    action="store",
    help="directory to store logs",
    default="/tmp",
)
parser.add_argument(
    '--timeout',
    help="timeout to terminate hung test",
    type=int,
    default=0,
)
parser.add_argument(
    '-l', '--loss',
    help="Simulate tcp packet loss",
    type=int,
    default=0,
)
parser.add_argument(
    '-c', '--corruption',
    help="Simulate tcp packet corruption",
    type=int,
    default=0,
)
parser.add_argument(
    '-u', '--duplicate',
    help="Simulate tcp packet duplication",
    type=int,
    default=0,
)
args = parser.parse_args()

logdir = args.logdir

# netem takes its probabilities as percentages
packet_loss = f'{args.loss}%'
packet_corruption = f'{args.corruption}%'
packet_duplicate = f'{args.duplicate}%'

# two namespaces joined by one veth pair
ip(f"netns add {net0}")
ip(f"netns add {net1}")
ip("link add type veth")

addrs = [
    # we technically don't need different port numbers, but this will
    # help identify traffic in the network analyzer
    ('10.0.0.1', 10000),
    ('10.0.0.2', 20000),
]

# move interfaces to separate namespaces so they can no longer be
# bound directly; this prevents rds from switching over from the tcp
# transport to the loop transport.
ip(f"link set {veth0} netns {net0} up")
ip(f"link set {veth1} netns {net1} up")

# add addresses
ip(f"-n {net0} addr add {addrs[0][0]}/32 dev {veth0}")
ip(f"-n {net1} addr add {addrs[1][0]}/32 dev {veth1}")

# add routes
ip(f"-n {net0} route add {addrs[1][0]}/32 dev {veth0}")
ip(f"-n {net1} route add {addrs[0][0]}/32 dev {veth1}")

# sanity check that our two interfaces/addresses are correctly set up
# and communicating by doing a single ping
ip(f"netns exec {net0} ping -c 1 {addrs[1][0]}")

# Start a packet capture on each network
for net in [net0, net1]:
    tcpdump_pid = os.fork()
    if tcpdump_pid == 0:
        pcap = logdir+'/'+net+'.pcap'
        # create the capture file up front so its owner can be looked up
        # and handed to tcpdump's -Z option
        subprocess.check_call(['touch', pcap])
        user = getpwuid(stat(pcap).st_uid).pw_name
        # NOTE(review): presumably ip() does not return until tcpdump
        # exits, which is why this runs in a forked child -- confirm
        # against the lib.py helper.
        ip(f"netns exec {net} /usr/sbin/tcpdump -Z {user} -i any -w {pcap}")
        sys.exit(0)

# simulate packet loss, duplication and corruption
for net, iface in [(net0, veth0), (net1, veth1)]:
    ip(f"netns exec {net} /usr/sbin/tc qdisc add dev {iface} root netem \
         corrupt {packet_corruption} loss {packet_loss} duplicate \
         {packet_duplicate}")

# add a timeout
if args.timeout > 0:
    # Install the handler before arming the alarm so SIGALRM can never
    # be delivered while its default (terminate) action is in place.
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(args.timeout)

sockets = [
    netns_socket(net0, socket.AF_RDS, socket.SOCK_SEQPACKET),
    netns_socket(net1, socket.AF_RDS, socket.SOCK_SEQPACKET),
]

for s, addr in zip(sockets, addrs):
    s.bind(addr)
    s.setblocking(0)

# lookup tables between sockets, their descriptors and their addresses
fileno_to_socket = {
    s.fileno(): s for s in sockets
}

addr_to_socket = {
    addr: s for addr, s in zip(addrs, sockets)
}

socket_to_addr = {
    s: addr for addr, s in zip(addrs, sockets)
}

# running SHA-256 per (sender fileno, receiver fileno) pair, fed with
# every payload sent / received on that pair
send_hashes = {}
recv_hashes = {}

ep = select.epoll()

for s in sockets:
    ep.register(s, select.EPOLLRDNORM)

n = 50000
nr_send = 0
nr_recv = 0

while nr_send < n:
    # Send as much as we can without blocking
    print("sending...", nr_send, nr_recv)
    while nr_send < n:
        send_data = hashlib.sha256(
            f'packet {nr_send}'.encode('utf-8')).hexdigest().encode('utf-8')

        # pseudo-random send/receive pattern
        sender = sockets[nr_send % 2]
        receiver = sockets[1 - (nr_send % 3) % 2]

        try:
            sender.sendto(send_data, socket_to_addr[receiver])
            send_hashes.setdefault((sender.fileno(), receiver.fileno()),
                hashlib.sha256()).update(f'<{send_data}>'.encode('utf-8'))
            nr_send += 1
        except BlockingIOError:
            break
        except OSError as e:
            # these errnos end the current send burst; the same packet
            # number is attempted again on the next outer iteration
            if e.errno in [errno.ENOBUFS, errno.ECONNRESET, errno.EPIPE]:
                break
            raise

    # Receive as much as we can without blocking
    print("receiving...", nr_send, nr_recv)
    while nr_recv < nr_send:
        for fileno, eventmask in ep.poll():
            receiver = fileno_to_socket[fileno]

            if eventmask & select.EPOLLRDNORM:
                # drain this socket until it would block
                while True:
                    try:
                        recv_data, address = receiver.recvfrom(1024)
                        sender = addr_to_socket[address]
                        recv_hashes.setdefault((sender.fileno(),
                            receiver.fileno()), hashlib.sha256()).update(
                                f'<{recv_data}>'.encode('utf-8'))
                        nr_recv += 1
                    except BlockingIOError:
                        break

    # exercise net/rds/tcp.c:rds_tcp_sysctl_reset()
    for net in [net0, net1]:
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_rcvbuf=10000")
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_sndbuf=10000")

print("done", nr_send, nr_recv)

# the Python socket module doesn't know these
RDS_INFO_FIRST = 10000
RDS_INFO_LAST = 10017

nr_success = 0
nr_error = 0

for s in sockets:
    for optname in range(RDS_INFO_FIRST, RDS_INFO_LAST + 1):
        # Sigh, the Python socket module doesn't allow us to pass
        # buffer lengths greater than 1024 for some reason. RDS
        # wants multiple pages.
        try:
            s.getsockopt(socket.SOL_RDS, optname, 1024)
            nr_success += 1
        except OSError:
            # Failures (ENOSPC included) are only counted, never fatal.
            nr_error += 1

print(f"getsockopt(): {nr_success}/{nr_error}")

print("Stopping network packet captures")
subprocess.check_call(['killall', '-q', 'tcpdump'])

# We're done sending and receiving stuff, now let's check if what
# we received is what we sent.
# Compare, for every (sender fileno, receiver fileno) pair that sent
# data, the digest of what was sent with the digest of what arrived.
for flow, sent in send_hashes.items():
    received = recv_hashes.get(flow)

    if received is None:
        print("FAIL: No data received")
        sys.exit(1)

    expected_digest = sent.hexdigest()
    received_digest = received.hexdigest()

    if expected_digest != received_digest:
        print("FAIL: Send/recv mismatch")
        print("hash expected:", expected_digest)
        print("hash received:", received_digest)
        sys.exit(1)

    src_fd, dst_fd = flow
    print(f"{src_fd}/{dst_fd}: ok")

print("Success")
sys.exit(0)