xref: /linux/tools/testing/selftests/net/rds/test.py (revision 40286d6379aacfcc053253ef78dc78b09addffda)
1#! /usr/bin/env python3
2# SPDX-License-Identifier: GPL-2.0
3
4import argparse
5import ctypes
6import errno
7import hashlib
8import os
9import select
10import signal
11import socket
12import subprocess
13import sys
14import tempfile
15import shutil
16
# Make the shared selftest helpers (selftests/net/lib/py) importable no
# matter which directory the script is started from: this file lives in
# net/rds/, so its parent directory is added to the import path.
this_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(this_dir, "../"))
from lib.py.utils import ip

# setns(2) is called straight through libc with ctypes (os.setns() only
# exists on Python 3.12 and newer).
libc = ctypes.CDLL('libc.so.6')
setns = libc.setns

# Names of the two network namespaces and of the veth ends placed in them.
NET0, NET1 = 'net0', 'net1'
VETH0, VETH1 = 'veth0', 'veth1'
30
# Helper function for creating a socket inside a network namespace.
# We need this because otherwise RDS will detect that the two TCP
# sockets are on the same interface and use the loop transport instead
# of the TCP transport.
def netns_socket(netns, *sock_args):
    """
    Create a socket inside a network namespace and return it.

    A forked child enters the namespace with setns(), creates the socket
    and passes its file descriptor back to the parent over a UNIX
    SOCK_SEQPACKET socketpair (SCM_RIGHTS). The calling process itself
    never switches namespaces.

    :param netns: the name of the network namespace (/var/run/netns/<netns>)
    :param sock_args: arguments for socket.socket(), i.e. family and type
    :returns: the socket created inside ``netns``
    :raises OSError: if the child could not enter the namespace or could
        not create the socket
    """
    u0, u1 = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)

    child = os.fork()
    if child == 0:
        # Child: must never return into the caller's code. Leave with
        # os._exit() so the parent's atexit handlers and inherited stdio
        # buffers are not run/flushed a second time from this process.
        status = 1
        try:
            u1.close()
            # change network namespace
            with open(f'/var/run/netns/{netns}', encoding='utf-8') as f:
                # setns is a raw ctypes call: failure is reported through
                # its return value, it never raises.
                if setns(f.fileno(), 0) != 0:
                    raise OSError(
                        f'setns() into network namespace {netns} failed')

            # create socket in target namespace
            sock = socket.socket(*sock_args)

            # send resulting socket to parent
            socket.send_fds(u0, [], [sock.fileno()])
            status = 0
        except BaseException as e:  # child boundary: report, then exit non-zero
            print(e, file=sys.stderr, flush=True)
        finally:
            os._exit(status)

    # Parent: drop our copy of the child's end so that, if the child exits
    # without sending a descriptor, recv_fds() sees EOF instead of hanging.
    u0.close()
    try:
        _, fds, _, _ = socket.recv_fds(u1, 0, 1)
    finally:
        u1.close()
    os.waitpid(child, 0)

    if not fds:
        raise OSError(f'could not create socket in network namespace {netns}')

    try:
        return socket.fromfd(fds[0], *sock_args)
    finally:
        # fromfd() dup()s the descriptor; close the one we received.
        os.close(fds[0])
68
def signal_handler(_sig, _frame):
    """
    SIGALRM handler used for the --timeout option.

    Reports that the test ran out of time and terminates it with exit
    status 1. Both arguments are the standard (signum, frame) pair and
    are ignored.
    """
    message = 'Test timed out'
    print(message)
    raise SystemExit(1)
75
# Parse out command line arguments. We take an optional timeout, an
# optional log output folder and the netem impairment percentages.
parser = argparse.ArgumentParser(
    description="init script args",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-d", "--logdir", default="/tmp",
                    help="directory to store logs")
parser.add_argument('--timeout', type=int, default=0,
                    help="timeout to terminate hung test")
parser.add_argument('-l', '--loss', type=int, default=0,
                    help="Simulate tcp packet loss")
parser.add_argument('-c', '--corruption', type=int, default=0,
                    help="Simulate tcp packet corruption")
parser.add_argument('-u', '--duplicate', type=int, default=0,
                    help="Simulate tcp packet duplication")
args = parser.parse_args()

logdir = args.logdir
# netem expects percentages such as "5%"
PACKET_LOSS, PACKET_CORRUPTION, PACKET_DUPLICATE = (
    f"{pct}%" for pct in (args.loss, args.corruption, args.duplicate))
95
# Create the two namespaces and one veth pair. No interface names are
# passed, so the pair is presumably auto-named veth0/veth1, which is what
# VETH0/VETH1 expect.
for _ns in (NET0, NET1):
    ip(f"netns add {_ns}")
ip("link add type veth")

addrs = [
    # we technically don't need different port numbers, but this will
    # help identify traffic in the network analyzer
    ('10.0.0.1', 10000),
    ('10.0.0.2', 20000),
]

# (namespace, veth end, local address, peer address) for each side
_sides = (
    (NET0, VETH0, addrs[0][0], addrs[1][0]),
    (NET1, VETH1, addrs[1][0], addrs[0][0]),
)

# move interfaces to separate namespaces so they can no longer be
# bound directly; this prevents rds from switching over from the tcp
# transport to the loop transport.
for _ns, _dev, _local, _peer in _sides:
    ip(f"link set {_dev} netns {_ns} up")

# give each end a /32 host address, then a /32 host route to its peer
for _ns, _dev, _local, _peer in _sides:
    ip(f"-n {_ns} addr add {_local}/32 dev {_dev}")
for _ns, _dev, _local, _peer in _sides:
    ip(f"-n {_ns} route add {_peer}/32 dev {_dev}")

# sanity check that our two interfaces/addresses are correctly set up
# and communicating by doing a single ping
ip(f"netns exec {NET0} ping -c 1 {addrs[1][0]}")
126
# Start a packet capture in each namespace. tcpdump writes into a fresh
# temporary file under /tmp; after the capture is stopped that file is
# moved to <logdir>/<net>.pcap. Each tcpdump_procs entry is
# (process, temporary path, final path, temporary file descriptor).
tcpdump_procs = []
for net in (NET0, NET1):
    tmp_fd, tmp_path = tempfile.mkstemp(suffix=".pcap", prefix=f"{net}-", dir="/tmp")
    capture = subprocess.Popen(
        ['ip', 'netns', 'exec', net,
         '/usr/sbin/tcpdump', '-i', 'any', '-w', tmp_path])
    tcpdump_procs.append((capture, tmp_path, f"{logdir}/{net}.pcap", tmp_fd))
136
# simulate packet loss, duplication and corruption: attach a netem root
# qdisc to the veth end inside each namespace
for netns_name, veth in zip((NET0, NET1), (VETH0, VETH1)):
    ip(f"netns exec {netns_name} /usr/sbin/tc qdisc add dev {veth} root netem  \
         corrupt {PACKET_CORRUPTION} loss {PACKET_LOSS} duplicate  \
         {PACKET_DUPLICATE}")
142
# add a timeout (seconds, via SIGALRM). Install the handler before arming
# the alarm: if the alarm fired first, SIGALRM's default action would kill
# the process without ever reaching signal_handler().
if args.timeout > 0:
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(args.timeout)
147
# One RDS socket per namespace (NET0 first, then NET1), each bound to its
# own address and switched to non-blocking mode.
sockets = [
    netns_socket(netns_name, socket.AF_RDS, socket.SOCK_SEQPACKET)
    for netns_name in (NET0, NET1)
]

for sock, addr in zip(sockets, addrs):
    sock.bind(addr)
    sock.setblocking(False)

# Lookup tables between sockets, their file descriptors and addresses.
fileno_to_socket = {sock.fileno(): sock for sock in sockets}
addr_to_socket = dict(zip(addrs, sockets))
socket_to_addr = dict(zip(sockets, addrs))

# Running hashes keyed by (sender fileno, receiver fileno): one table for
# what was sent, one for what was received.
send_hashes = {}
recv_hashes = {}

# Watch both sockets for readable data (level-triggered: no EPOLLET).
ep = select.epoll()
for sock in sockets:
    ep.register(sock, select.EPOLLRDNORM)
174
def _record(hashes, sender, receiver, data):
    """
    Fold one message into the running SHA-256 of a (sender, receiver) pair.

    :param hashes: dict mapping (sender fileno, receiver fileno) to a
        hashlib.sha256 object; a missing entry is created on first use
    :param sender: socket the message was sent from
    :param receiver: socket the message was sent to
    :param data: the message payload (bytes)
    """
    key = (sender.fileno(), receiver.fileno())
    digest = hashes.get(key)
    if digest is None:
        # Create the hash object only the first time a pair is seen;
        # setdefault(key, hashlib.sha256()) would build and discard a new
        # one for every message.
        digest = hashes[key] = hashlib.sha256()
    # The <...> wrapper makes message boundaries part of the hashed stream.
    digest.update(f'<{data}>'.encode('utf-8'))


NUM_PACKETS = 50000
nr_send = 0
nr_recv = 0

while nr_send < NUM_PACKETS:
    # Send as much as we can without blocking
    print("sending...", nr_send, nr_recv)
    while nr_send < NUM_PACKETS:
        send_data = hashlib.sha256(
            f'packet {nr_send}'.encode('utf-8')).hexdigest().encode('utf-8')

        # pseudo-random send/receive pattern; some messages go from a
        # socket to itself
        sender = sockets[nr_send % 2]
        receiver = sockets[1 - (nr_send % 3) % 2]

        try:
            sender.sendto(send_data, socket_to_addr[receiver])
        except BlockingIOError:
            break
        except OSError as e:
            # these errors end the send burst; the outer loop retries after
            # the receivers have been drained
            if e.errno in [errno.ENOBUFS, errno.ECONNRESET, errno.EPIPE]:
                break
            raise
        _record(send_hashes, sender, receiver, send_data)
        nr_send += 1

    # Receive as much as we can without blocking
    print("receiving...", nr_send, nr_recv)
    while nr_recv < nr_send:
        # ep.poll() has no timeout: if messages never arrive this waits
        # until the optional --timeout alarm aborts the test
        for fileno, eventmask in ep.poll():
            receiver = fileno_to_socket[fileno]

            if not eventmask & select.EPOLLRDNORM:
                continue
            # drain this socket until it would block
            while True:
                try:
                    recv_data, address = receiver.recvfrom(1024)
                except BlockingIOError:
                    break
                sender = addr_to_socket[address]
                _record(recv_hashes, sender, receiver, recv_data)
                nr_recv += 1

    # exercise net/rds/tcp.c:rds_tcp_sysctl_reset()
    for net in [NET0, NET1]:
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_rcvbuf=10000")
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_sndbuf=10000")

print("done", nr_send, nr_recv)
226
# RDS_INFO_* getsockopt option numbers; the Python socket module doesn't
# know these
RDS_INFO_FIRST = 10000
RDS_INFO_LAST = 10017

nr_success = 0
nr_error = 0

# Query every RDS_INFO_* option on both sockets. Python's socket module
# caps the getsockopt() buffer at 1024 bytes while RDS wants multiple
# pages, so failures (e.g. ENOSPC) are expected here; they are only
# counted, never raised.
for sock in sockets:
    for opt in range(RDS_INFO_FIRST, RDS_INFO_LAST + 1):
        try:
            sock.getsockopt(socket.SOL_RDS, opt, 1024)
        except OSError:
            nr_error += 1
        else:
            nr_success += 1

print(f"getsockopt(): {nr_success}/{nr_error}")
249
print("Stopping network packet captures")
for capture, tmp_path, final_path, tmp_fd in tcpdump_procs:
    # terminate() sends SIGTERM; wait for tcpdump to exit before the
    # temporary capture file is closed and moved into the log directory.
    capture.terminate()
    capture.wait()
    os.close(tmp_fd)
    shutil.move(tmp_path, final_path)
256
# We're done sending and receiving stuff, now let's check if what
# we received is what we sent: for every (sender, receiver) pair that
# sent data, the running hash of received messages must match.
for pair, expected in send_hashes.items():
    sender, receiver = pair
    received = recv_hashes.get(pair)

    if received is None:
        print("FAIL: No data received")
        sys.exit(1)

    want = expected.hexdigest()
    got = received.hexdigest()
    if want != got:
        print("FAIL: Send/recv mismatch")
        print("hash expected:", want)
        print("hash received:", got)
        sys.exit(1)

    print(f"{sender}/{receiver}: ok")

print("Success")
sys.exit(0)
276