xref: /linux/tools/testing/selftests/net/rds/test.py (revision 7d3ab852dcd853692197a830d20052127754a087)
1#! /usr/bin/env python3
2# SPDX-License-Identifier: GPL-2.0
3"""
4This module provides functional testing for the net/rds component.
5"""
6
7import argparse
8import ctypes
9import errno
10import hashlib
11import os
12import select
13import signal
14import socket
15import subprocess
16import sys
17
18# Allow utils module to be imported from different directory
19this_dir = os.path.dirname(os.path.realpath(__file__))
20sys.path.append(os.path.join(this_dir, "../"))
21# pylint: disable-next=wrong-import-position,import-error,no-name-in-module
22from lib.py.utils import ip # noqa: E402
23# pylint: disable-next=wrong-import-position,import-error,no-name-in-module
24from lib.py.ksft import ksft_pr # noqa: E402
25
# setns(2) is reached through libc with ctypes.  A ctypes foreign call
# reports failure through its integer return value; it does not raise.
libc = ctypes.CDLL('libc.so.6')
setns = libc.setns

# names of the two network namespaces used by the test
NET0, NET1 = 'net0', 'net1'

# the two ends of the veth pair; one end is moved into each namespace
VETH0, VETH1 = 'veth0', 'veth1'
34
35# Helper function for creating a socket inside a network namespace.
36# We need this because otherwise RDS will detect that the two TCP
37# sockets are on the same interface and use the loop transport instead
38# of the TCP transport.
39def netns_socket(netns, *sock_args):
40    """
41    Creates sockets inside of network namespace
42
43    :param netns: the name of the network namespace
44    :param sock_args: socket family and type
45    """
46    u0, u1 = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)
47
48    child = os.fork()
49    if child == 0:
50        # change network namespace
51        with open(f'/var/run/netns/{netns}', encoding='utf-8') as f:
52            try:
53                setns(f.fileno(), 0)
54            except IOError as e:
55                print(e.errno)
56                print(e)
57
58        # create socket in target namespace
59        sock = socket.socket(*sock_args)
60
61        # send resulting socket to parent
62        socket.send_fds(u0, [], [sock.fileno()])
63
64        os._exit(0)
65
66    # receive socket from child
67    _, fds, _, _ = socket.recv_fds(u1, 0, 1)
68    os.waitpid(child, 0)
69    u0.close()
70    u1.close()
71    return socket.fromfd(fds[0], *sock_args)
72
def stop_pcaps():
    """Terminate every tcpdump capture process that is still running.

    Processes are popped off ``tcpdump_procs`` as they are stopped.  A
    later call (for example the normal exit path after the timeout handler
    already ran) therefore finds nothing left to do.  The list is empty
    when no log directory was requested, and then this is a no-op.
    """
    if tcpdump_procs:
        ksft_pr("Stopping network packet captures")

    while tcpdump_procs:
        capture = tcpdump_procs.pop()
        capture.terminate()
        try:
            capture.wait(timeout=5)
        except subprocess.TimeoutExpired:
            # tcpdump did not exit after SIGTERM: kill it and reap it
            capture.kill()
            capture.wait()
93
def signal_handler(_sig, _frame):
    """
    Handle the SIGALRM raised by the --timeout watchdog.

    Stops any running packet captures, reports the single TAP test as
    failed and terminates the script with exit status 1.
    """
    ksft_pr("Test timed out")
    stop_pcaps()
    print("not ok 1 rds selftest")
    raise SystemExit(1)
102
# Parse out command line arguments: an optional timeout, an optional log
# output folder, and optional netem impairment percentages for packet
# loss, corruption and duplication.
parser = argparse.ArgumentParser(
    description="init script args",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("-d", "--logdir", action="store", default=None,
                    help="directory to store logs")
parser.add_argument('-t', '--timeout', type=int, default=0,
                    help="timeout to terminate hung test")
parser.add_argument('-l', '--loss', type=int, default=0,
                    help="Simulate tcp packet loss")
parser.add_argument('-c', '--corruption', type=int, default=0,
                    help="Simulate tcp packet corruption")
parser.add_argument('-u', '--duplicate', type=int, default=0,
                    help="Simulate tcp packet duplication")
args = parser.parse_args()

logdir = args.logdir
# netem expects percentages such as "5%"
PACKET_LOSS, PACKET_CORRUPTION, PACKET_DUPLICATE = (
    f'{pct}%' for pct in (args.loss, args.corruption, args.duplicate)
)
122
ip(f"netns add {NET0}")
ip(f"netns add {NET1}")
# Name both ends of the veth pair explicitly.  With the kernel's default
# "vethN" naming, the pair silently gets other names when veth0 or veth1
# already exist, and the commands below would then act on the wrong
# devices.  With explicit names the command fails loudly instead.
ip(f"link add {VETH0} type veth peer name {VETH1}")

addrs = [
    # we technically don't need different port numbers, but this will
    # help identify traffic in the network analyzer
    ('10.0.0.1', 10000),
    ('10.0.0.2', 20000),
]

# move interfaces to separate namespaces so they can no longer be
# bound directly; this prevents rds from switching over from the tcp
# transport to the loop transport.
ip(f"link set {VETH0} netns {NET0} up")
ip(f"link set {VETH1} netns {NET1} up")

# add addresses
ip(f"-n {NET0} addr add {addrs[0][0]}/32 dev {VETH0}")
ip(f"-n {NET1} addr add {addrs[1][0]}/32 dev {VETH1}")

# add routes
ip(f"-n {NET0} route add {addrs[1][0]}/32 dev {VETH0}")
ip(f"-n {NET1} route add {addrs[0][0]}/32 dev {VETH1}")

# sanity check that our two interfaces/addresses are correctly set up
# and communicating by doing a single ping
ip(f"netns exec {NET0} ping -c 1 {addrs[1][0]}")
153
# Start a packet capture on each network when a log directory was given.
# tcpdump_procs is always defined so that stop_pcaps() can drain it.
tcpdump_procs = []
if logdir is not None:
    sudo_user = os.environ.get('SUDO_USER')
    for net in [NET0, NET1]:
        pcap = f'{logdir}/rds-{net}.pcap'

        tcpdump_cmd = ['ip', 'netns', 'exec', net, '/usr/sbin/tcpdump']
        # when run under sudo, have tcpdump switch to the invoking user (-Z)
        if sudo_user:
            tcpdump_cmd += ['-Z', sudo_user]
        tcpdump_cmd += ['-i', 'any', '-w', pcap]

        # pylint: disable-next=consider-using-with
        tcpdump_procs.append(subprocess.Popen(
            tcpdump_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL))
170
# simulate packet loss, duplication and corruption on both veth ends.
# Implicit string concatenation keeps the command free of the indentation
# that a backslash-newline inside the literal would embed in it.
for net, iface in [(NET0, VETH0), (NET1, VETH1)]:
    ip(f"netns exec {net} /usr/sbin/tc qdisc add dev {iface} root netem "
       f"corrupt {PACKET_CORRUPTION} loss {PACKET_LOSS} "
       f"duplicate {PACKET_DUPLICATE}")
176
print("TAP version 13")
print("1..1")

# add a timeout.  Install the handler before arming the alarm: a SIGALRM
# delivered in between would otherwise hit the default action, which
# terminates the process without emitting a TAP result.
if args.timeout > 0:
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(args.timeout)
184
# One RDS socket per namespace, bound to that namespace's address and made
# non-blocking so the loops below can rely on BlockingIOError.
sockets = [
    netns_socket(netns, socket.AF_RDS, socket.SOCK_SEQPACKET)
    for netns in (NET0, NET1)
]

for sock, bind_addr in zip(sockets, addrs):
    sock.bind(bind_addr)
    sock.setblocking(False)

# lookup tables between sockets, their descriptors and their addresses
fileno_to_socket = {sock.fileno(): sock for sock in sockets}
addr_to_socket = {a: sock for a, sock in zip(addrs, sockets)}
socket_to_addr = dict(zip(sockets, addrs))

# running SHA-256 per (sender fd, receiver fd) pair
send_hashes = {}
recv_hashes = {}

ep = select.epoll()
for sock in sockets:
    ep.register(sock, select.EPOLLRDNORM)

NUM_PACKETS = 50000
nr_send = nr_recv = 0
215
# Alternate send and receive phases until every packet has been sent.
# Each send phase queues as much as the sockets accept without blocking.
# The following receive phase drains everything sent so far.
while nr_send < NUM_PACKETS:
    ksft_pr("sending...", nr_send, nr_recv)
    while nr_send < NUM_PACKETS:
        payload = hashlib.sha256(
            f'packet {nr_send}'.encode('utf-8')).hexdigest().encode('utf-8')

        # pseudo-random send/receive pattern (a socket may send to itself)
        src = sockets[nr_send % 2]
        dst = sockets[1 - (nr_send % 3) % 2]

        try:
            src.sendto(payload, socket_to_addr[dst])
        except BlockingIOError:
            break
        except OSError as e:
            if e.errno not in (errno.ENOBUFS, errno.ECONNRESET, errno.EPIPE):
                raise
            break

        send_hashes.setdefault((src.fileno(), dst.fileno()),
                               hashlib.sha256()).update(
            f'<{payload}>'.encode('utf-8'))
        nr_send += 1

    ksft_pr("receiving...", nr_send, nr_recv)
    while nr_recv < nr_send:
        for fd, events in ep.poll():
            dst = fileno_to_socket[fd]
            if not events & select.EPOLLRDNORM:
                continue
            # read until the socket would block
            while True:
                try:
                    data, peer = dst.recvfrom(1024)
                except BlockingIOError:
                    break
                src = addr_to_socket[peer]
                recv_hashes.setdefault((src.fileno(), dst.fileno()),
                                       hashlib.sha256()).update(
                    f'<{data}>'.encode('utf-8'))
                nr_recv += 1

    # exercise net/rds/tcp.c:rds_tcp_sysctl_reset()
    for net in [NET0, NET1]:
        for knob in ('rds_tcp_rcvbuf', 'rds_tcp_sndbuf'):
            ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.{knob}=10000")

ksft_pr("done", nr_send, nr_recv)
263
# the Python socket module doesn't know these
RDS_INFO_FIRST = 10000
RDS_INFO_LAST = 10017

nr_success = 0
nr_error = 0

for s in sockets:
    for optname in range(RDS_INFO_FIRST, RDS_INFO_LAST + 1):
        # Sigh, the Python socket module doesn't allow us to pass
        # buffer lengths greater than 1024 for some reason. RDS
        # wants multiple pages.
        try:
            s.getsockopt(socket.SOL_RDS, optname, 1024)
        except OSError as e:
            nr_error += 1
            # ENOSPC is expected, since RDS wants more than the 1024 bytes
            # we can offer.  Other errors are tolerated as well, but log
            # them rather than dropping them silently.
            if e.errno != errno.ENOSPC:
                ksft_pr(f"getsockopt({optname}) failed: {e}")
        else:
            nr_success += 1

ksft_pr(f"getsockopt(): {nr_success}/{nr_error}")
286
# cancel timeout
signal.alarm(0)

stop_pcaps()

# Sending and receiving are finished; compare, for every sender/receiver
# pair, the digest of what was sent with the digest of what arrived.
ret = 0
for (sender, receiver), send_hash in send_hashes.items():
    recv_hash = recv_hashes.get((sender, receiver))
    if recv_hash is None:
        ksft_pr("FAIL: No data received")
        ret = 1
        break

    expected = send_hash.hexdigest()
    received = recv_hash.hexdigest()
    if expected != received:
        ksft_pr("FAIL: Send/recv mismatch")
        ksft_pr("hash expected:", expected)
        ksft_pr("hash received:", received)
        ret = 1
        break

    ksft_pr(f"{sender}/{receiver}: ok")

if ret:
    print("not ok 1 rds selftest")
else:
    ksft_pr("Success")
    print("ok 1 rds selftest")

ksft_pr(f"Totals: pass:{1-ret} fail:{ret} skip:0")
sys.exit(ret)
320