xref: /linux/tools/testing/selftests/net/rds/test.py (revision 6a4c4656b0d2d4056a1f0c35442db4e8a5cf8021)
1#! /usr/bin/env python3
2# SPDX-License-Identifier: GPL-2.0
3"""
4This module provides functional testing for the net/rds component.
5"""
6
7import argparse
8import ctypes
9import errno
10import hashlib
11import os
12import select
13import signal
14import socket
15import subprocess
16import sys
17
18# Allow utils module to be imported from different directory
19this_dir = os.path.dirname(os.path.realpath(__file__))
20sys.path.append(os.path.join(this_dir, "../"))
21# pylint: disable-next=wrong-import-position,import-error,no-name-in-module
22from lib.py.utils import ip # noqa: E402
23# pylint: disable-next=wrong-import-position,import-error,no-name-in-module
24from lib.py.ksft import ksft_pr # noqa: E402
25
# setns(2) is reached through ctypes.  use_errno=True makes ctypes save the
# C errno after every call into this library, so ctypes.get_errno() can
# report why a failing setns() returned -1; without it get_errno() does not
# reflect the call at all.
libc = ctypes.CDLL('libc.so.6', use_errno=True)
setns = libc.setns
# int setns(int fd, int nstype)
setns.argtypes = [ctypes.c_int, ctypes.c_int]
setns.restype = ctypes.c_int

# Names of the two network namespaces the test runs in.
NET0 = 'net0'
NET1 = 'net1'

# Names of the two ends of the veth pair; VETH0 is moved into NET0 and
# VETH1 into NET1.
VETH0 = 'veth0'
VETH1 = 'veth1'
34
# Helper function for creating a socket inside a network namespace.
# We need this because otherwise RDS will detect that the two TCP
# sockets are on the same interface and use the loop transport instead
# of the TCP transport.
def netns_socket(netns, *sock_args):
    """
    Creates a socket inside of a network namespace.

    A child process is forked, enters the namespace with setns(2), creates
    the socket there and hands its file descriptor back to the parent over
    a UNIX socketpair using socket.send_fds().

    :param netns: the name of the network namespace (an entry under
        /var/run/netns)
    :param sock_args: socket family and type, as passed to socket.socket()
    :returns: the socket that was created inside ``netns``
    :raises RuntimeError: if the child could not enter the namespace or
        could not create/send the socket
    """
    u0, u1 = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)

    child = os.fork()
    if child == 0:
        # Child: always leave through os._exit().  Neither an exception nor
        # a normal return may let the child fall back into the caller and
        # run the rest of the test script as a second copy.
        status = 1
        try:
            u1.close()
            # change network namespace
            with open(f'/var/run/netns/{netns}', encoding='utf-8') as f:
                # setns is a raw ctypes binding: failure is reported by a
                # non-zero return value, never by a Python exception.
                if setns(f.fileno(), 0) != 0:
                    err = ctypes.get_errno()
                    raise OSError(err, os.strerror(err) if err else
                                  f'setns() into {netns} failed')

            # create socket in target namespace
            sock = socket.socket(*sock_args)

            # send resulting socket to parent
            socket.send_fds(u0, [], [sock.fileno()])
            status = 0
        except Exception as e:  # report anything; the parent raises
            print(f'netns_socket({netns}): {e}', file=sys.stderr, flush=True)
        finally:
            os._exit(status)

    # Parent: drop our copy of the child's end first.  If the child exits
    # without sending a descriptor, recv_fds() then sees EOF instead of
    # blocking forever.
    u0.close()
    try:
        # receive socket from child
        _, fds, _, _ = socket.recv_fds(u1, 0, 1)
    finally:
        u1.close()
        _, status = os.waitpid(child, 0)

    if not fds:
        raise RuntimeError(f'could not create socket in netns {netns} '
                           f'(child wait status {status:#x})')

    # socket.fromfd() duplicates the descriptor, so close the received one
    # to avoid leaking a file descriptor on every call.
    try:
        return socket.fromfd(fds[0], *sock_args)
    finally:
        os.close(fds[0])
72
def stop_pcaps():
    """Shut down every tcpdump capture recorded in ``tcpdump_procs``.

    Each capture first gets terminate() and up to five seconds to exit.
    If it is still running after that, it is killed and reaped.

    Entries are popped off the list as they are handled.  A second call
    (for example the timeout handler firing and then the normal end of the
    test) therefore finds nothing left to stop.  The list is empty when no
    log directory was given.
    """
    ksft_pr("Stopping network packet captures")
    while tcpdump_procs:
        capture = tcpdump_procs.pop()
        capture.terminate()
        try:
            capture.wait(timeout=5)
            continue
        except subprocess.TimeoutExpired:
            pass
        # Still alive after the grace period: force it down.
        capture.kill()
        capture.wait()
89
def signal_handler(_sig, _frame):
    """
    SIGALRM handler used as the --timeout watchdog.

    Logs the timeout and stops any packet captures. It then prints the TAP
    failure line for the single test case and exits with status 1. The
    signal number and stack frame arguments are unused.
    """
    ksft_pr("Test timed out")
    stop_pcaps()
    print("not ok 1 rds selftest")
    raise SystemExit(1)
98
# Command line options, all optional.  With the defaults the test runs:
# - without a watchdog timeout,
# - without packet captures,
# - with 0% netem loss, corruption and duplication.
parser = argparse.ArgumentParser(
    description="init script args",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-d", "--logdir", action="store", default=None,
                    help="directory to store logs")
parser.add_argument('-t', '--timeout', type=int, default=0,
                    help="timeout to terminate hung test")
parser.add_argument('-l', '--loss', type=int, default=0,
                    help="Simulate tcp packet loss")
parser.add_argument('-c', '--corruption', type=int, default=0,
                    help="Simulate tcp packet corruption")
parser.add_argument('-u', '--duplicate', type=int, default=0,
                    help="Simulate tcp packet duplication")
args = parser.parse_args()
logdir = args.logdir

# The netem parameters are passed to tc as percentage strings, e.g. "5%".
PACKET_LOSS = f'{args.loss}%'
PACKET_CORRUPTION = f'{args.corruption}%'
PACKET_DUPLICATE = f'{args.duplicate}%'
118
# Two network namespaces connected by a veth pair.
ip(f"netns add {NET0}")
ip(f"netns add {NET1}")
# Name both ends of the pair explicitly.  A bare "link add type veth"
# lets the kernel pick the first free vethN names, which are only
# veth0/veth1 when no other vethN device exists.  Otherwise the commands
# below could move an unrelated, pre-existing interface.
ip(f"link add {VETH0} type veth peer name {VETH1}")

addrs = [
    # we technically don't need different port numbers, but this will
    # help identify traffic in the network analyzer
    ('10.0.0.1', 10000),
    ('10.0.0.2', 20000),
]

# move interfaces to separate namespaces so they can no longer be
# bound directly; this prevents rds from switching over from the tcp
# transport to the loop transport.
ip(f"link set {VETH0} netns {NET0} up")
ip(f"link set {VETH1} netns {NET1} up")

# add addresses: one /32 host address per namespace
ip(f"-n {NET0} addr add {addrs[0][0]}/32 dev {VETH0}")
ip(f"-n {NET1} addr add {addrs[1][0]}/32 dev {VETH1}")

# add routes: a /32 address brings no subnet route with it, so each side
# needs an explicit route to its peer's address
ip(f"-n {NET0} route add {addrs[1][0]}/32 dev {VETH0}")
ip(f"-n {NET1} route add {addrs[0][0]}/32 dev {VETH1}")

# sanity check that our two interfaces/addresses are correctly set up
# and communicating by doing a single ping
ip(f"netns exec {NET0} ping -c 1 {addrs[1][0]}")
149
# Optional packet capture: with a log directory, record all traffic seen in
# each namespace to <logdir>/<netns>.pcap.  The Popen handles are kept in
# tcpdump_procs so stop_pcaps() can shut the captures down again.
tcpdump_procs = []
if logdir is not None:
    # Under sudo, have tcpdump switch to the invoking user (-Z) so the
    # capture files are written as that user.
    sudo_user = os.environ.get('SUDO_USER')
    for net in (NET0, NET1):
        tcpdump_cmd = ['ip', 'netns', 'exec', net, '/usr/sbin/tcpdump']
        if sudo_user:
            tcpdump_cmd += ['-Z', sudo_user]
        tcpdump_cmd += ['-i', 'any', '-w', f'{logdir}/{net}.pcap']

        # pylint: disable-next=consider-using-with
        tcpdump_procs.append(subprocess.Popen(
            tcpdump_cmd,
            stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL))

# Impair both veth ends with netem using the -c/-l/-u percentages.
for net, iface in [(NET0, VETH0), (NET1, VETH1)]:
    ip(f"netns exec {net} /usr/sbin/tc qdisc add dev {iface} root netem  \
         corrupt {PACKET_CORRUPTION} loss {PACKET_LOSS} duplicate  \
         {PACKET_DUPLICATE}")
172
# TAP header: this script reports exactly one test case.
print("TAP version 13")
print("1..1")

# Optional watchdog: if the test is still running after --timeout seconds,
# signal_handler() reports the failure and exits.  The handler must be
# installed before the alarm is armed.  A SIGALRM delivered in between
# would otherwise take the default action and kill the process without
# printing any TAP result.
if args.timeout > 0:
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(args.timeout)
180
# One RDS socket per namespace, created inside that namespace so RDS
# uses the TCP transport rather than loopback.
sockets = [netns_socket(ns, socket.AF_RDS, socket.SOCK_SEQPACKET)
           for ns in (NET0, NET1)]

# Bind each socket to its address; all I/O below is non-blocking and
# driven by epoll.
for sock, addr in zip(sockets, addrs):
    sock.bind(addr)
    sock.setblocking(False)

# Lookup tables between file descriptors, sockets and bound addresses.
fileno_to_socket = {sock.fileno(): sock for sock in sockets}
addr_to_socket = dict(zip(addrs, sockets))
socket_to_addr = dict(zip(sockets, addrs))

# Running SHA-256 digests of the traffic, keyed by
# (sender fileno, receiver fileno).
send_hashes = {}
recv_hashes = {}

ep = select.epoll()
for sock in sockets:
    ep.register(sock, select.EPOLLRDNORM)

NUM_PACKETS = 50000
nr_send = 0
nr_recv = 0
211
# Main traffic loop.  Each round does three things:
# 1. sends until the sockets would block (or NUM_PACKETS is reached);
# 2. drains every message sent so far;
# 3. rewrites the RDS TCP buffer sysctls.
# Both directions are folded into per-(sender, receiver) SHA-256 digests
# that are compared after the loop.
while nr_send < NUM_PACKETS:
    ksft_pr("sending...", nr_send, nr_recv)
    while nr_send < NUM_PACKETS:
        payload = hashlib.sha256(
            f'packet {nr_send}'.encode('utf-8')).hexdigest().encode('utf-8')

        # Fixed but irregular endpoint pattern: the sender alternates on
        # every message while the receiver cycles 1, 0, 1.  All four
        # (sender, receiver) combinations occur, including a socket
        # sending to itself.
        src = sockets[nr_send % 2]
        dst = sockets[1 - (nr_send % 3) % 2]

        try:
            src.sendto(payload, socket_to_addr[dst])
        except BlockingIOError:
            break
        except OSError as e:
            # These errors are tolerated: the same message is retried in
            # the next round.  Anything else aborts the test.
            if e.errno not in [errno.ENOBUFS, errno.ECONNRESET, errno.EPIPE]:
                raise
            break
        send_hashes.setdefault((src.fileno(), dst.fileno()),
                               hashlib.sha256()).update(
                                   f'<{payload}>'.encode('utf-8'))
        nr_send += 1

    ksft_pr("receiving...", nr_send, nr_recv)
    while nr_recv < nr_send:
        for fileno, eventmask in ep.poll():
            dst = fileno_to_socket[fileno]
            if not eventmask & select.EPOLLRDNORM:
                continue

            # Drain this socket completely before polling again.
            while True:
                try:
                    data, peer = dst.recvfrom(1024)
                except BlockingIOError:
                    break
                src = addr_to_socket[peer]
                recv_hashes.setdefault((src.fileno(), dst.fileno()),
                                       hashlib.sha256()).update(
                                           f'<{data}>'.encode('utf-8'))
                nr_recv += 1

    # Writing these sysctls exercises net/rds/tcp.c:rds_tcp_sysctl_reset().
    for net in (NET0, NET1):
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_rcvbuf=10000")
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_sndbuf=10000")
257
ksft_pr("done", nr_send, nr_recv)

# RDS_INFO_* getsockopt() option numbers (the Python socket module does
# not define them).
RDS_INFO_FIRST = 10000
RDS_INFO_LAST = 10017

# Query every RDS_INFO_* option on both sockets.  Python's getsockopt()
# will not take a buffer length above 1024 bytes, while RDS wants
# multiple pages, so failures (ENOSPC in particular) are expected.  They
# are only counted and never fail the test.
nr_success = 0
nr_error = 0
for sock in sockets:
    for optname in range(RDS_INFO_FIRST, RDS_INFO_LAST + 1):
        try:
            sock.getsockopt(socket.SOL_RDS, optname, 1024)
        except OSError:
            nr_error += 1
        else:
            nr_success += 1

ksft_pr(f"getsockopt(): {nr_success}/{nr_error}")
stop_pcaps()
283
# Verify the traffic: for every (sender, receiver) pair that was sent to,
# the digest of the received messages must equal the digest of the sent
# ones (same messages, in the same order).  Stop at the first failure.
ret = 1
for (sender, receiver), send_hash in send_hashes.items():
    recv_hash = recv_hashes.get((sender, receiver))
    if recv_hash is None:
        ksft_pr("FAIL: No data received")
        break

    expected = send_hash.hexdigest()
    received = recv_hash.hexdigest()
    if expected != received:
        ksft_pr("FAIL: Send/recv mismatch")
        ksft_pr("hash expected:", expected)
        ksft_pr("hash received:", received)
        break

    ksft_pr(f"{sender}/{receiver}: ok")
else:
    # No pair failed.
    ret = 0

if ret == 0:
    ksft_pr("Success")
    print("ok 1 rds selftest")
else:
    print("not ok 1 rds selftest")

ksft_pr(f"Totals: pass:{1-ret} fail:{ret} skip:0")
sys.exit(ret)
312