xref: /linux/tools/testing/selftests/net/rds/test.py (revision 566ab427f827b0256d3e8ce0235d088e6a9c28bd)
1#! /usr/bin/env python3
2# SPDX-License-Identifier: GPL-2.0
3
4import argparse
5import ctypes
6import errno
7import hashlib
8import os
9import select
10import signal
11import socket
12import subprocess
13import sys
14import atexit
15from pwd import getpwuid
16from os import stat
17from lib.py import ip
18
19
# setns(2) is invoked directly through the C library via ctypes.
libc = ctypes.CDLL('libc.so.6')
setns = libc.setns

# The two network namespaces the RDS endpoints are placed in.
net0, net1 = 'net0', 'net1'

# Names of the two ends of the veth pair linking those namespaces.
veth0, veth1 = 'veth0', 'veth1'
28
def netns_socket(netns, *args):
    """Create a socket inside network namespace *netns* and return it.

    RDS would detect two TCP sockets on the same interface and use its loop
    transport instead of the TCP transport, so each endpoint has to live in
    its own namespace.  A forked child enters *netns* with setns(2), creates
    the socket there and passes the descriptor back to this process over an
    AF_UNIX socketpair.

    Args:
        netns: name of a namespace under /var/run/netns.
        *args: forwarded unchanged to socket.socket() (family, type, ...).

    Returns:
        A socket.socket object owned by the calling process.

    Raises:
        RuntimeError: the child could not enter *netns* or create the
            socket, so no descriptor was received.
    """
    u0, u1 = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)

    child = os.fork()
    if child == 0:
        # Child: always leave through os._exit() so inherited atexit
        # handlers and still-buffered stdio are not run/flushed a second
        # time, and so an exception can never let the child fall through
        # into the rest of the script.
        status = 1
        try:
            u1.close()
            with open(f'/var/run/netns/{netns}') as f:
                # setns is a raw ctypes call: failure is reported by a
                # non-zero return value (-1), never by an exception.
                if setns(f.fileno(), 0) != 0:
                    print(f'setns() into {netns} failed', file=sys.stderr,
                          flush=True)
                else:
                    # create socket in target namespace and send it to parent
                    s = socket.socket(*args)
                    socket.send_fds(u0, [], [s.fileno()])
                    status = 0
        except Exception as e:
            print(f'netns_socket({netns}): {e}', file=sys.stderr, flush=True)
        finally:
            os._exit(status)

    # Parent: close our copy of the child's end first, so that if the child
    # exits without sending anything recv_fds() sees EOF instead of blocking
    # forever.
    u0.close()
    try:
        _, fds, _, _ = socket.recv_fds(u1, 0, 1)
    finally:
        u1.close()
        os.waitpid(child, 0)

    if not fds:
        raise RuntimeError(f'could not create socket in netns {netns}')

    # fromfd() dup()s the descriptor; close the received original so it
    # does not leak.
    try:
        return socket.fromfd(fds[0], *args)
    finally:
        os.close(fds[0])
60
def signal_handler(sig, frame):
    """Report that the test ran past its timeout and exit with status 1.

    Installed for SIGALRM; both arguments come from the signal module and
    are not used.
    """
    print('Test timed out')
    raise SystemExit(1)
64
# Parse the command-line arguments: an optional timeout (seconds) for
# terminating a hung test, a directory for the packet-capture logs, and
# optional netem loss/corruption/duplication percentages.
parser = argparse.ArgumentParser(
    description="init script args",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("-d", "--logdir", action="store", default="/tmp",
                    help="directory to store logs")
parser.add_argument('--timeout', type=int, default=0,
                    help="timeout to terminate hung test")
parser.add_argument('-l', '--loss', type=int, default=0,
                    help="Simulate tcp packet loss")
parser.add_argument('-c', '--corruption', type=int, default=0,
                    help="Simulate tcp packet corruption")
parser.add_argument('-u', '--duplicate', type=int, default=0,
                    help="Simulate tcp packet duplication")
args = parser.parse_args()

logdir = args.logdir
# Handed to netem as percentage strings, e.g. "5%".
packet_loss = f'{args.loss}%'
packet_corruption = f'{args.corruption}%'
packet_duplicate = f'{args.duplicate}%'
84
# Create the two namespaces and the veth pair that links them.  Both ends of
# the pair are named explicitly: a bare "link add type veth" lets the kernel
# pick the device names, so the veth0/veth1 configured below would only be
# the new pair if those names happened to be free (otherwise an unrelated,
# pre-existing device would be moved and reconfigured).
ip(f"netns add {net0}")
ip(f"netns add {net1}")
ip(f"link add {veth0} type veth peer name {veth1}")

addrs = [
    # we technically don't need different port numbers, but this will
    # help identify traffic in the network analyzer
    ('10.0.0.1', 10000),
    ('10.0.0.2', 20000),
]

# Move the interfaces to separate namespaces so they can no longer be
# bound directly; this prevents RDS from switching over from the TCP
# transport to the loop transport.
ip(f"link set {veth0} netns {net0} up")
ip(f"link set {veth1} netns {net1} up")

# Give each end a host (/32) address ...
ip(f"-n {net0} addr add {addrs[0][0]}/32 dev {veth0}")
ip(f"-n {net1} addr add {addrs[1][0]}/32 dev {veth1}")

# ... and a host route to the peer's address over the veth link.
ip(f"-n {net0} route add {addrs[1][0]}/32 dev {veth0}")
ip(f"-n {net1} route add {addrs[0][0]}/32 dev {veth1}")

# Sanity check that the two interfaces/addresses are correctly set up and
# communicating by doing a single ping.
ip(f"netns exec {net0} ping -c 1 {addrs[1][0]}")
115
# Start a packet capture on each network.  Each capture runs in a forked
# child that runs tcpdump via ip() and exits once that call returns.
for net in [net0, net1]:
    tcpdump_pid = os.fork()
    if tcpdump_pid == 0:
        # Child: always leave through os._exit() rather than sys.exit(), so
        # inherited atexit handlers and buffered stdio are not run/flushed a
        # second time and an exception cannot make the child continue
        # executing the rest of the test script.
        status = 1
        try:
            pcap = os.path.join(logdir, f'{net}.pcap')
            # Pre-create the capture file and pass its owner to tcpdump -Z,
            # so the capture runs as the user owning the file.
            subprocess.check_call(['touch', pcap])
            user = getpwuid(stat(pcap).st_uid).pw_name
            ip(f"netns exec {net} /usr/sbin/tcpdump -Z {user} -i any -w {pcap}")
            status = 0
        except Exception as e:
            print(f'packet capture in {net} failed: {e}', file=sys.stderr,
                  flush=True)
        finally:
            os._exit(status)
125
# Simulate packet corruption, loss and duplication with netem on both ends
# of the link, using the percentages taken from the command line.  The
# command is built by implicit string concatenation rather than a backslash
# continuation inside the literal, which would embed the next line's
# indentation in the command text.
for net, iface in [(net0, veth0), (net1, veth1)]:
    ip(f"netns exec {net} /usr/sbin/tc qdisc add dev {iface} root netem "
       f"corrupt {packet_corruption} loss {packet_loss} "
       f"duplicate {packet_duplicate}")
131
# Optional watchdog: after args.timeout seconds SIGALRM fires and
# signal_handler() fails the test.  The handler is installed before the
# alarm is armed; arming first leaves a window in which SIGALRM's default
# action (terminate the process) would apply instead.
if args.timeout > 0:
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(args.timeout)
136
# One RDS socket per namespace; each is created inside its own namespace
# by netns_socket() so RDS uses the TCP transport rather than loop.
sockets = [
    netns_socket(ns, socket.AF_RDS, socket.SOCK_SEQPACKET)
    for ns in (net0, net1)
]

# Bind each socket to its namespace's address and make it non-blocking, so
# the send/receive loops stop as soon as a call would block.
for addr, s in zip(addrs, sockets):
    s.bind(addr)
    s.setblocking(False)

# Lookup tables used by the send/receive loops.
fileno_to_socket = {sock.fileno(): sock for sock in sockets}
addr_to_socket = dict(zip(addrs, sockets))
socket_to_addr = dict(zip(sockets, addrs))

# Running SHA-256 per (sender fd, receiver fd) pair over everything sent
# and everything received; compared once the traffic is done.
send_hashes = {}
recv_hashes = {}

# Wake up when either socket has data to read.
ep = select.epoll()
for s in sockets:
    ep.register(s, select.EPOLLRDNORM)

n = 50000    # total number of messages to send
nr_send = 0  # messages sent so far
nr_recv = 0  # messages received so far
169
# Alternate between a send phase and a receive phase until all n messages
# have been sent.  A send phase ends early when the sender would block or
# reports one of the transient errors handled below; the following receive
# phase then drains until everything sent so far has been received.
while nr_send < n:
    # Send as much as we can without blocking
    print("sending...", nr_send, nr_recv)
    while nr_send < n:
        # Payload: the hex SHA-256 of "packet <index>" (64 ASCII bytes).
        send_data = hashlib.sha256(
            f'packet {nr_send}'.encode('utf-8')).hexdigest().encode('utf-8')

        # pseudo-random send/receive pattern
        sender = sockets[nr_send % 2]
        receiver = sockets[1 - (nr_send % 3) % 2]

        try:
            sender.sendto(send_data, socket_to_addr[receiver])
            # Hash and count the message only after sendto() succeeded.
            send_hashes.setdefault((sender.fileno(), receiver.fileno()),
                    hashlib.sha256()).update(f'<{send_data}>'.encode('utf-8'))
            nr_send = nr_send + 1
        except BlockingIOError as e:
            break
        except OSError as e:
            # These end the send phase; the same message is retried after
            # the receive phase below.
            if e.errno in [errno.ENOBUFS, errno.ECONNRESET, errno.EPIPE]:
                break
            raise

    # Receive as much as we can without blocking
    print("receiving...", nr_send, nr_recv)
    while nr_recv < nr_send:
        # NOTE(review): ep.poll() is called without a timeout, so if an
        # expected message never arrives this loop only ends via the
        # optional --timeout alarm.
        for fileno, eventmask in ep.poll():
            receiver = fileno_to_socket[fileno]

            if eventmask & select.EPOLLRDNORM:
                # Drain this socket until it would block.
                while True:
                    try:
                        recv_data, address = receiver.recvfrom(1024)
                        # Attribute the message to its sender by source address.
                        sender = addr_to_socket[address]
                        recv_hashes.setdefault((sender.fileno(),
                            receiver.fileno()), hashlib.sha256()).update(
                                    f'<{recv_data}>'.encode('utf-8'))
                        nr_recv = nr_recv + 1
                    except BlockingIOError as e:
                        break

    # exercise net/rds/tcp.c:rds_tcp_sysctl_reset()
    for net in [net0, net1]:
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_rcvbuf=10000")
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_sndbuf=10000")

print("done", nr_send, nr_recv)
217
# Range of RDS_INFO_* getsockopt option numbers; the Python socket module
# doesn't know these constants.
RDS_INFO_FIRST = 10000
RDS_INFO_LAST = 10017

nr_success = 0
nr_error = 0

# Query every RDS_INFO_* option on both sockets and only tally the outcome.
# CPython's getsockopt() caps the buffer length at 1024 bytes while RDS
# wants multiple pages, so errors (presumably ENOSPC for the too-small
# buffer) are expected here and are counted, not treated as failures.
for s in sockets:
    for optname in range(RDS_INFO_FIRST, RDS_INFO_LAST + 1):
        try:
            s.getsockopt(socket.SOL_RDS, optname, 1024)
        except OSError:
            nr_error += 1
        else:
            nr_success += 1

print(f"getsockopt(): {nr_success}/{nr_error}")
240
print("Stopping network packet captures")
subprocess.check_call(['killall', '-q', 'tcpdump'])

# We're done sending and receiving stuff, now let's check if what we
# received is what we sent: for every (sender fd, receiver fd) stream the
# running hash over all received messages must equal the hash over all
# sent messages, in the same order.
for (sender, receiver), send_hash in send_hashes.items():
    recv_hash = recv_hashes.get((sender, receiver))

    if recv_hash is None:
        print("FAIL: No data received")
        sys.exit(1)

    expected = send_hash.hexdigest()
    received = recv_hash.hexdigest()
    if expected != received:
        print("FAIL: Send/recv mismatch")
        print("hash expected:", expected)
        print("hash received:", received)
        sys.exit(1)

    print(f"{sender}/{receiver}: ok")

# The loop above only visits streams that sent something; data that showed
# up on a stream nobody sent on would otherwise go unnoticed.
unexpected = recv_hashes.keys() - send_hashes.keys()
if unexpected:
    print("FAIL: Received data that was never sent:", sorted(unexpected))
    sys.exit(1)

print("Success")
sys.exit(0)
263