xref: /linux/tools/testing/selftests/net/rds/test.py (revision a3f143c461444c0b56360bbf468615fa814a8372)
1#! /usr/bin/env python3
2# SPDX-License-Identifier: GPL-2.0
3
4import argparse
5import ctypes
6import errno
7import hashlib
8import os
9import select
10import signal
11import socket
12import subprocess
13import sys
14import atexit
15from pwd import getpwuid
16from os import stat
17
18# Allow utils module to be imported from different directory
19this_dir = os.path.dirname(os.path.realpath(__file__))
20sys.path.append(os.path.join(this_dir, "../"))
21from lib.py.utils import ip
22
23libc = ctypes.cdll.LoadLibrary('libc.so.6')
24setns = libc.setns
25
26net0 = 'net0'
27net1 = 'net1'
28
29veth0 = 'veth0'
30veth1 = 'veth1'
31
32# Helper function for creating a socket inside a network namespace.
33# We need this because otherwise RDS will detect that the two TCP
34# sockets are on the same interface and use the loop transport instead
35# of the TCP transport.
def netns_socket(netns, *args):
    """Create a socket inside the network namespace *netns*.

    A short-lived child process switches into the namespace, creates the
    socket there and passes its file descriptor back to the caller over a
    UNIX socket pair.

    Args:
        netns: name of the namespace; it is opened as /var/run/netns/<netns>.
        *args: positional arguments forwarded to socket.socket()
            (family, type and optionally protocol).

    Returns:
        A socket.socket object referring to the socket created by the child.

    Raises:
        RuntimeError: if the child exited without handing over a socket
            (for example because entering the namespace failed).
    """
    u0, u1 = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)

    child = os.fork()
    if child == 0:
        status = 1
        try:
            # the child only ever sends on u0
            u1.close()

            # change network namespace; setns() is called through ctypes,
            # so a failure shows up as a non-zero return value rather than
            # as a Python exception
            with open(f'/var/run/netns/{netns}') as f:
                if setns(f.fileno(), 0) != 0:
                    raise OSError(f'setns() into {netns} failed')

            # create socket in target namespace
            s = socket.socket(*args)

            # send resulting socket to parent
            socket.send_fds(u0, [], [s.fileno()])
            status = 0
        except Exception as e:
            print(f'netns_socket({netns}): {e}', file=sys.stderr, flush=True)
        finally:
            # leave the forked copy of the script immediately, without
            # running the parent's cleanup handlers
            os._exit(status)

    # Drop our copy of the child's end first: if the child dies before
    # sending anything, the receive below then sees EOF instead of
    # blocking forever.
    u0.close()
    try:
        # receive socket from child
        _, fds, _, _ = socket.recv_fds(u1, 0, 1)
    finally:
        u1.close()
        os.waitpid(child, 0)

    if not fds:
        raise RuntimeError(f'no socket received from namespace {netns}')

    try:
        return socket.fromfd(fds[0], *args)
    finally:
        # socket.fromfd() duplicates the descriptor, so the received one
        # must be closed here or it leaks
        os.close(fds[0])
63
def signal_handler(sig, frame):
    """Signal handler: report the timeout and exit with status 1.

    Both arguments (signal number and current stack frame) are ignored.
    """
    print('Test timed out')
    raise SystemExit(1)
67
# Parse the command line.  Everything is optional: a log output folder,
# a timeout for hung runs, and three packet impairment percentages.
parser = argparse.ArgumentParser(
    description="init script args",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("-d", "--logdir", action="store", default="/tmp",
                    help="directory to store logs")
parser.add_argument('--timeout', type=int, default=0,
                    help="timeout to terminate hung test")
# The impairment options all have the same shape: an integer defaulting to 0.
for short_opt, long_opt, help_text in (
        ('-l', '--loss', "Simulate tcp packet loss"),
        ('-c', '--corruption', "Simulate tcp packet corruption"),
        ('-u', '--duplicate', "Simulate tcp packet duplication"),
):
    parser.add_argument(short_opt, long_opt, help=help_text,
                        type=int, default=0)
args = parser.parse_args()

logdir = args.logdir
# the impairment values are passed on as percentages, e.g. "5%"
packet_loss = f"{args.loss}%"
packet_corruption = f"{args.corruption}%"
packet_duplicate = f"{args.duplicate}%"
87
# Create both namespaces and a veth pair; the commands below refer to
# the two ends of the pair as veth0 and veth1.
for ns in (net0, net1):
    ip(f"netns add {ns}")
ip("link add type veth")

# (address, port) of each endpoint, in the same order as the namespaces.
addrs = [
    # we technically don't need different port numbers, but this will
    # help identify traffic in the network analyzer
    ('10.0.0.1', 10000),
    ('10.0.0.2', 20000),
]

# move interfaces to separate namespaces so they can no longer be
# bound directly; this prevents rds from switching over from the tcp
# transport to the loop transport.
for dev, ns in ((veth0, net0), (veth1, net1)):
    ip(f"link set {dev} netns {ns} up")

# add addresses: each interface gets its own endpoint address
for ns, dev, (own_ip, _) in zip((net0, net1), (veth0, veth1), addrs):
    ip(f"-n {ns} addr add {own_ip}/32 dev {dev}")

# add routes: each namespace reaches the other endpoint through its interface
for ns, dev, (peer_ip, _) in zip((net0, net1), (veth0, veth1), reversed(addrs)):
    ip(f"-n {ns} route add {peer_ip}/32 dev {dev}")

# sanity check that our two interfaces/addresses are correctly set up
# and communicating by doing a single ping
ip(f"netns exec {net0} ping -c 1 {addrs[1][0]}")
118
# Start a packet capture on each network.  Each capture is driven by a
# forked child; the parent just moves on to the next namespace.
for ns in (net0, net1):
    tcpdump_pid = os.fork()
    if tcpdump_pid != 0:
        continue

    # child: make sure the capture file exists, then run tcpdump with
    # -Z set to the name of the user owning that file
    capture_file = f'{logdir}/{ns}.pcap'
    subprocess.check_call(['touch', capture_file])
    owner = getpwuid(stat(capture_file).st_uid).pw_name
    ip(f"netns exec {ns} /usr/sbin/tcpdump -Z {owner} -i any -w {capture_file}")
    sys.exit(0)
128
# simulate packet loss, duplication and corruption by attaching a netem
# qdisc to the interface in each namespace
for ns, dev in ((net0, veth0), (net1, veth1)):
    netem_cmd = (
        f"netns exec {ns} /usr/sbin/tc qdisc add dev {dev} root netem  "
        f"         corrupt {packet_corruption} loss {packet_loss} duplicate  "
        f"         {packet_duplicate}"
    )
    ip(netem_cmd)
134
# add a timeout: when --timeout is positive, SIGALRM fires after that
# many seconds and signal_handler terminates the test.  The handler is
# installed before the alarm is armed so the signal can never arrive
# while the default action (killing the process without a message) is
# still in place.
if args.timeout > 0:
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(args.timeout)
139
# One non-blocking RDS socket per namespace, bound to the address at the
# same position in addrs.
sockets = [
    netns_socket(ns, socket.AF_RDS, socket.SOCK_SEQPACKET)
    for ns in (net0, net1)
]

for sock, local_addr in zip(sockets, addrs):
    sock.bind(local_addr)
    sock.setblocking(False)

# Lookup tables used by the transfer loop.
fileno_to_socket = {sock.fileno(): sock for sock in sockets}
addr_to_socket = dict(zip(addrs, sockets))
socket_to_addr = dict(zip(sockets, addrs))

# Running hashes of everything sent/received, filled in by the transfer
# loop and keyed by (sender fileno, receiver fileno).
send_hashes = {}
recv_hashes = {}

# Watch both sockets for incoming data.
ep = select.epoll()
for sock in sockets:
    ep.register(sock, select.EPOLLRDNORM)
168
# Total number of packets to transfer, and progress counters.
n = 50000
nr_send = 0
nr_recv = 0

while nr_send < n:
    # Send as much as we can without blocking
    print("sending...", nr_send, nr_recv)
    while nr_send < n:
        # the payload is derived from the packet number
        packet_digest = hashlib.sha256(f'packet {nr_send}'.encode('utf-8'))
        payload = packet_digest.hexdigest().encode('utf-8')

        # pseudo-random send/receive pattern
        sender = sockets[nr_send % 2]
        receiver = sockets[1 - (nr_send % 3) % 2]

        try:
            sender.sendto(payload, socket_to_addr[receiver])
        except BlockingIOError:
            break
        except OSError as e:
            # these errors end the current send burst; anything else is fatal
            if e.errno in [errno.ENOBUFS, errno.ECONNRESET, errno.EPIPE]:
                break
            raise
        else:
            # fold the payload into the running hash for this socket pair
            pair = (sender.fileno(), receiver.fileno())
            send_hashes.setdefault(pair, hashlib.sha256()).update(
                f'<{payload}>'.encode('utf-8'))
            nr_send += 1

    # Receive as much as we can without blocking
    print("receiving...", nr_send, nr_recv)
    while nr_recv < nr_send:
        for fd, events in ep.poll():
            receiver = fileno_to_socket[fd]

            if not events & select.EPOLLRDNORM:
                continue

            # drain this socket until it would block
            while True:
                try:
                    data, peer = receiver.recvfrom(1024)
                except BlockingIOError:
                    break
                sender = addr_to_socket[peer]
                pair = (sender.fileno(), receiver.fileno())
                recv_hashes.setdefault(pair, hashlib.sha256()).update(
                    f'<{data}>'.encode('utf-8'))
                nr_recv += 1

    # exercise net/rds/tcp.c:rds_tcp_sysctl_reset()
    for ns in (net0, net1):
        for knob in ("rds_tcp_rcvbuf", "rds_tcp_sndbuf"):
            ip(f"netns exec {ns} /usr/sbin/sysctl net.rds.tcp.{knob}=10000")

print("done", nr_send, nr_recv)
218
219print("done", nr_send, nr_recv)
220
# the Python socket module doesn't know these
RDS_INFO_FIRST = 10000
RDS_INFO_LAST = 10017

# Tally of getsockopt() calls that returned / raised OSError.
nr_success = 0
nr_error = 0

for sock in sockets:
    for optname in range(RDS_INFO_FIRST, RDS_INFO_LAST + 1):
        # Sigh, the Python socket module doesn't allow us to pass
        # buffer lengths greater than 1024 for some reason. RDS
        # wants multiple pages.
        try:
            sock.getsockopt(socket.SOL_RDS, optname, 1024)
        except OSError:
            # every failure is counted, ENOSPC included; nothing else
            # is done with the error
            nr_error += 1
        else:
            nr_success += 1

print(f"getsockopt(): {nr_success}/{nr_error}")
243
print("Stopping network packet captures")
# killall is run with check_call, so a non-zero exit status raises
# CalledProcessError.
stop_capture_cmd = ['killall', '-q', 'tcpdump']
subprocess.check_call(stop_capture_cmd)
246
# We're done sending and receiving stuff, now let's check if what
# we received is what we sent.  Every (sender, receiver) pair that sent
# something must have a receive hash, and the two digests must agree.
for pair, sent in send_hashes.items():
    received = recv_hashes.get(pair)

    if received is None:
        print("FAIL: No data received")
        sys.exit(1)

    expected = sent.hexdigest()
    actual = received.hexdigest()
    if expected != actual:
        print("FAIL: Send/recv mismatch")
        print("hash expected:", expected)
        print("hash received:", actual)
        sys.exit(1)

    sender, receiver = pair
    print(f"{sender}/{receiver}: ok")

print("Success")
sys.exit(0)
266