xref: /linux/tools/testing/selftests/net/rds/test.py (revision 6443f4f20bdae726fe01cf5946fba9742a0ffda6)
1#! /usr/bin/env python3
2# SPDX-License-Identifier: GPL-2.0
3"""
4This module provides functional testing for the net/rds component.
5"""
6
7import argparse
8import atexit
9import ctypes
10import errno
11import hashlib
12import os
13import select
14import re
15import signal
16import socket
17import subprocess
18import sys
19import time
20
21# Allow utils module to be imported from different directory
22this_dir = os.path.dirname(os.path.realpath(__file__))
23sys.path.append(os.path.join(this_dir, "../"))
24# pylint: disable-next=wrong-import-position,import-error,no-name-in-module
25from lib.py.utils import ip, cmd # noqa: E402
26# pylint: disable-next=wrong-import-position,import-error,no-name-in-module
27from lib.py.ksft import ksft_pr # noqa: E402
28
# setns(2) is reached through ctypes; netns_socket() calls it in a forked
# child to switch into a named network namespace.
libc = ctypes.cdll.LoadLibrary('libc.so.6')
setns = libc.setns

# network namespaces created by setup_tcp()
NET0 = 'net0'
NET1 = 'net1'

# veth interfaces that setup_tcp() places in NET0 / NET1
VETH0 = 'veth0'
VETH1 = 'veth1'

# tcpdump Popen objects appended by setup_tcp() / setup_rdma() and
# drained by stop_pcaps()
tcpdump_procs = []
tcp_addrs = [
    # we technically don't need different port numbers, but this will
    # help identify traffic in the network analyzer
    ('10.0.0.1', 10000),
    ('10.0.0.2', 20000),
]

# RDMA network configs
RXE_DEV0 = 'rxe0'
RXE_DEV1 = 'rxe1'

VETH_RDMA0 = 'veth_rdma0'
VETH_RDMA1 = 'veth_rdma1'

# (ip, port) pairs bound by the two RDMA test sockets
rdma_addrs = [
    ('10.0.0.3', 30000),
    ('10.0.0.4', 30000),
]

# send_packets flag space
# (snd_rcv_packets() requires exactly one of these to be set)
OP_FLAG_TCP     = 0x1
OP_FLAG_RDMA    = 0x2

# from include/uapi/linux/rds.h: SO_RDS_TRANSPORT pins a socket to a
# specific RDS transport so connection setup cannot silently fall back
# to another (e.g. loopback) transport.
SOL_RDS          = 276
SO_RDS_TRANSPORT = 8
RDS_TRANS_TCP    = 2
RDS_TRANS_IB     = 0

# name of the transport under test; set by the main loop before the
# alarm is armed and reported by signal_handler() on timeout
signal_handler_label = ""

# TAP bookkeeping: index of the current test and pass/fail counters
tap_idx = 0
nr_pass = 0
nr_fail = 0
75
# Helper function for creating a socket inside a network namespace.
# We need this because otherwise RDS will detect that the two TCP
# sockets are on the same interface and use the loop transport instead
# of the TCP transport.
def netns_socket(netns, *sock_args):
    """
    Create a socket inside a network namespace and return it.

    A child process is forked, switches into the namespace with setns(),
    creates the socket there and passes its file descriptor back to the
    parent over a unix socketpair.

    :param netns: the name of the network namespace
    :param sock_args: socket family and type
    :returns: a socket object for the descriptor created in the namespace
    :raises RuntimeError: if the child failed or sent no descriptor
    """
    u0, u1 = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)

    child = os.fork()
    if child == 0:
        try:
            # the child only sends; drop its copy of the receiving end
            u1.close()

            # change network namespace
            with open(f'/var/run/netns/{netns}', encoding='utf-8') as f:
                # setns() is called through ctypes, so a failure shows up
                # as a non-zero return value rather than an exception
                if setns(f.fileno(), 0) != 0:
                    raise OSError(f"setns() failed for netns {netns}")

            # create socket in target namespace
            sock = socket.socket(*sock_args)

            # send resulting socket to parent
            socket.send_fds(u0, [], [sock.fileno()])

            os._exit(0)
        except BaseException:
            os._exit(1)

    # Close the parent's copy of the sending end before receiving.  If the
    # child dies without sending anything, the receive below then returns
    # with no descriptors instead of blocking forever.
    u0.close()
    try:
        # receive socket from child
        _, fds, _, _ = socket.recv_fds(u1, 0, 1)
    finally:
        u1.close()
    _, status = os.waitpid(child, 0)

    child_ok = os.WIFEXITED(status) and os.WEXITSTATUS(status) == 0
    try:
        if not child_ok or not fds:
            raise RuntimeError(
                f"netns_socket child failed in netns {netns} (status={status})")
        # socket.fromfd() duplicates the descriptor, so the received one
        # is closed below to avoid leaking it.
        return socket.fromfd(fds[0], *sock_args)
    finally:
        for fd in fds:
            os.close(fd)
114
def send_burst(socks, ip_addrs, snd_hashes, nr_sent, nr_total):
    """
    Send packets until a socket would block or nr_total is reached.

    :param socks: the two sockets taking part in the test
    :param ip_addrs: (ip, port) tuples, indexed like socks
    :param snd_hashes: dict of running sha256 objects keyed by
                       (sender fileno, receiver fileno); updated in place
    :param nr_sent: number of packets sent so far
    :param nr_total: total number of packets to send
    :returns: the updated number of packets sent
    """
    # errors that end this burst without failing the test
    stop_errnos = (errno.ENOBUFS, errno.ECONNRESET, errno.EPIPE)

    for seq in range(nr_sent, nr_total):
        digest = hashlib.sha256(f'packet {seq}'.encode('utf-8')).hexdigest()
        payload = digest.encode('utf-8')

        # pseudo-random send/receive pattern
        src = seq % 2
        dst = 1 - (seq % 3) % 2
        sender, receiver = socks[src], socks[dst]

        try:
            sender.sendto(payload, ip_addrs[dst])
        except BlockingIOError:
            return seq
        except OSError as err:
            if err.errno not in stop_errnos:
                raise
            return seq

        pair = (sender.fileno(), receiver.fileno())
        if pair not in snd_hashes:
            snd_hashes[pair] = hashlib.sha256()
        snd_hashes[pair].update(f'<{payload}>'.encode('utf-8'))

    # Nothing is sent when nr_sent already is at or beyond nr_total; the
    # count is then handed back unchanged.
    return max(nr_sent, nr_total)
139
def recv_burst(epoll, socks, ip_addrs, rcv_hashes, nr_rcv):
    """
    Drain every socket that epoll reports as readable.

    :param epoll: epoll object the sockets are registered with
    :param socks: the two sockets taking part in the test
    :param ip_addrs: (ip, port) tuples, indexed like socks
    :param rcv_hashes: dict of running sha256 objects keyed by
                       (sender fileno, receiver fileno); updated in place
    :param nr_rcv: number of packets received so far
    :returns: the updated number of packets received
    """
    readable = [fd for fd, mask in epoll.poll() if mask & select.EPOLLRDNORM]

    for fd in readable:
        receiver = next(sock for sock in socks if sock.fileno() == fd)
        while True:
            try:
                payload, peer = receiver.recvfrom(1024)
            except BlockingIOError:
                # nothing left to read on this socket
                break

            # the source address identifies the sending socket
            sender = socks[ip_addrs.index(peer)]
            pair = (sender.fileno(), receiver.fileno())
            if pair not in rcv_hashes:
                rcv_hashes[pair] = hashlib.sha256()
            rcv_hashes[pair].update(f'<{payload}>'.encode('utf-8'))
            nr_rcv += 1

    return nr_rcv
157
def check_info(socks):
    """
    Query every rds info page on each socket and report the tally.

    Failures are only counted and logged; they never fail the test.

    :param socks: list of sockets to check
    """

    # the Python socket module doesn't know these
    rds_info_first = 10000
    rds_info_last = 10017
    info_opts = range(rds_info_first, rds_info_last + 1)

    ok_count = 0
    err_count = 0

    for sock in socks:
        for opt in info_opts:
            # Sigh, the Python socket module doesn't allow us to pass
            # buffer lengths greater than 1024 for some reason. RDS
            # wants multiple pages.
            try:
                sock.getsockopt(socket.SOL_RDS, opt, 1024)
            except OSError:
                # every errno, ENOSPC included, is tallied the same way
                err_count += 1
            else:
                ok_count += 1

    ksft_pr(f"getsockopt(): {ok_count}/{err_count}")
187
def verify_hashes(snd_hashes, rcv_hashes):
    """
    Check that each (sender, receiver) pair received what was sent.

    :param snd_hashes: dict of sha256 objects keyed by (sender, receiver)
    :param rcv_hashes: dict of sha256 objects keyed by (sender, receiver)
    :returns: 0 when every pair in snd_hashes matches, 1 on the first
              missing or mismatching pair
    """
    for pair in snd_hashes:
        expected = snd_hashes[pair]
        received = rcv_hashes.get(pair)

        if received is None:
            ksft_pr("FAIL: No data received")
            return 1

        if expected.hexdigest() != received.hexdigest():
            ksft_pr("FAIL: Send/recv mismatch")
            ksft_pr("hash expected:", expected.hexdigest())
            ksft_pr("hash received:", received.hexdigest())
            return 1

        sender_fd, receiver_fd = pair[0], pair[1]
        ksft_pr(f"{sender_fd}/{receiver_fd}: ok")

    return 0
202
def snd_rcv_packets(env):
    """
    Send packets on the given network interfaces

    Two RDS sockets are created, pinned to the requested transport and
    bound to the configured addresses.  Packets are then exchanged in
    bursts and the per-pair hashes of sent and received data compared.
    The sockets and the epoll object are always closed, also when an
    error is raised part-way through.

    :param env: transport-environment dict for setup_tcp() / setup_rdma().
                "addrs": list of (ip, port) tuples matching the sockets
                "netns": list of netns names for TCP or None for RDMA
                "flags": OP_FLAG_TCP or OP_FLAG_RDMA, selects sockets
    :returns: 0 if everything sent was received intact, 1 otherwise
    :raises RuntimeError: if flags selects no transport or both
    """

    addrs = env["addrs"]
    netns_list = env["netns"]
    flags = env.get("flags", 0)

    if (flags & OP_FLAG_TCP) and (flags & OP_FLAG_RDMA):
        raise RuntimeError(f"Invalid transport flag sets multiple transports: {flags}")

    if not flags & (OP_FLAG_TCP | OP_FLAG_RDMA):
        raise RuntimeError(f"Invalid transport flag sets no transports: {flags}")

    sockets = []
    ep = None
    try:
        # Sockets are appended one at a time so that the first is still
        # closed if creating the second one fails.
        if flags & OP_FLAG_TCP:
            for idx in (0, 1):
                sockets.append(netns_socket(netns_list[idx], socket.AF_RDS,
                                            socket.SOCK_SEQPACKET))
            transport = RDS_TRANS_TCP
        else:
            for _ in (0, 1):
                sockets.append(socket.socket(socket.AF_RDS,
                                             socket.SOCK_SEQPACKET))
            transport = RDS_TRANS_IB

        # Pin the sockets to the selected transport so it doesn't fail
        # over to a different transport during this test
        for s in sockets:
            s.setsockopt(SOL_RDS, SO_RDS_TRANSPORT, transport)

        for s, addr in zip(sockets, addrs):
            s.bind(addr)
            s.setblocking(0)

        send_hashes = {}
        recv_hashes = {}

        ep = select.epoll()

        for s in sockets:
            ep.register(s, select.EPOLLRDNORM)

        num_packets = 50000
        nr_send = 0
        nr_recv = 0

        while nr_send < num_packets:

            # Send as much as we can without blocking
            ksft_pr("sending...", nr_send, nr_recv)
            nr_send = send_burst(sockets, addrs, send_hashes, nr_send, num_packets)

            # Receive as much as we can without blocking
            ksft_pr("receiving...", nr_send, nr_recv)
            while nr_recv < nr_send:
                nr_recv = recv_burst(ep, sockets, addrs, recv_hashes, nr_recv)

            # exercise net/rds/tcp.c:rds_tcp_sysctl_reset()
            if netns_list:
                for net in netns_list:
                    ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_rcvbuf=10000")
                    ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_sndbuf=10000")

        ksft_pr("done", nr_send, nr_recv)

        check_info(sockets)

        # We're done sending and receiving stuff, now let's check if what
        # we received is what we sent.
        return verify_hashes(send_hashes, recv_hashes)
    finally:
        if ep is not None:
            ep.close()
        for s in sockets:
            s.close()
289
def stop_pcaps():
    """Terminate every running tcpdump process.

    Processes are popped off tcpdump_procs one at a time, so the list is
    drained and a second call (for example at exit after the timeout
    handler fired) finds nothing left to do.  The list is empty when
    logdir is not set.
    """

    if len(tcpdump_procs) == 0:
        return

    ksft_pr("Stopping network packet captures")
    while tcpdump_procs:
        capture = tcpdump_procs.pop()
        # ask politely first, then force-kill if it does not exit in time
        capture.terminate()
        try:
            capture.wait(timeout=5)
        except subprocess.TimeoutExpired:
            capture.kill()
            capture.wait()
310
def signal_handler(_sig, _frame):
    """
    Handle the test timeout signal: report the failure and exit.

    :param _sig: signal number (unused)
    :param _frame: interrupted stack frame (unused)
    """
    label = signal_handler_label
    ksft_pr(f"Test timed out: {label}")
    # emit the TAP result for the test that was running
    print(f"not ok {tap_idx} rds selftest {label}")
    raise SystemExit(1)
318
def setup_tcp():
    """
    Configure tcp network

    Creates two network namespaces joined by a veth pair, assigns the
    addresses from tcp_addrs, optionally starts a tcpdump capture per
    namespace (when logdir is set) and installs a netem qdisc that
    simulates packet corruption, loss and duplication.
    """

    # clean up any leftovers from a previously interrupted run
    teardown_tcp()

    ip(f"netns add {NET0}")
    ip(f"netns add {NET1}")

    # Create both ends of the veth pair by name, directly inside their
    # namespaces.  This does not depend on the kernel auto-assigning the
    # names veth0/veth1 and cannot pick up an unrelated interface of the
    # same name in the current namespace.  Keeping the TCP interfaces in
    # separate namespaces means they cannot be bound directly; this
    # prevents rds from switching over from the tcp transport to the loop
    # transport.
    ip(f"link add {VETH0} netns {NET0} type veth peer name {VETH1} netns {NET1}")
    ip(f"-n {NET0} link set {VETH0} up")
    ip(f"-n {NET1} link set {VETH1} up")

    # add addresses
    ip(f"-n {NET0} addr add {tcp_addrs[0][0]}/32 dev {VETH0}")
    ip(f"-n {NET1} addr add {tcp_addrs[1][0]}/32 dev {VETH1}")

    # add routes
    ip(f"-n {NET0} route add {tcp_addrs[1][0]}/32 dev {VETH0}")
    ip(f"-n {NET1} route add {tcp_addrs[0][0]}/32 dev {VETH1}")

    # sanity check that our two interfaces/addresses are correctly set up
    # and communicating by doing a single ping
    ip(f"netns exec {NET0} ping -c 1 {tcp_addrs[1][0]}")

    # Start a packet capture on each network
    if logdir is not None:
        for netn in [NET0, NET1]:
            pcap = logdir+'/rds-'+netn+'.pcap'

            tcpdump_cmd = ['ip', 'netns', 'exec', netn, '/usr/sbin/tcpdump']
            # when SUDO_USER is set, pass it to tcpdump via -Z
            sudo_user = os.environ.get('SUDO_USER')
            if sudo_user:
                tcpdump_cmd.extend(['-Z', sudo_user])
            tcpdump_cmd.extend(['-i', 'any', '-w', pcap])

            # pylint: disable-next=consider-using-with
            p = subprocess.Popen(tcpdump_cmd,
                                 stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            tcpdump_procs.append(p)

    # simulate packet loss, duplication and corruption
    for netn, iface in [(NET0, VETH0), (NET1, VETH1)]:
        ip(f"netns exec {netn} /usr/sbin/tc qdisc add dev {iface} root netem  \
             corrupt {PACKET_CORRUPTION} loss {PACKET_LOSS} duplicate  \
             {PACKET_DUPLICATE}")
370
def teardown_tcp():
    """
    Remove the network namespaces created by setup_tcp().

    Deleting a namespace also discards the veth end, addresses, routes
    and netem qdisc inside it.  The commands run with fail=False, so the
    function can be called from error paths whether setup completed,
    stopped part-way, or never ran.
    """
    for netns in (NET0, NET1):
        cmd(f"ip netns del {netns}", fail=False)
381
def get_iface_mac(iface):
    """
    Look up the MAC address of a local network interface.

    :param iface: interface name as understood by "ip link show"
    :returns: the address following "link/ether" in the command output
    :raises RuntimeError: if the output contains no link/ether address
    """
    link_info = subprocess.check_output(['ip', 'link', 'show', iface], text=True)
    found = re.search(r'link/ether\s+([0-9a-f:]+)', link_info)
    if found is None:
        raise RuntimeError(f"Cannot determine MAC address of {iface}")
    return found.group(1)
389
def setup_rdma():
    """
    Configure rdma network

    Builds a veth pair in the current namespace, assigns the addresses
    from rdma_addrs, attaches an rxe device to each end, optionally
    starts a tcpdump capture per interface (when logdir is set) and
    installs a netem qdisc on both interfaces.
    """
    ifaces = (VETH_RDMA0, VETH_RDMA1)
    local_ips = (rdma_addrs[0][0], rdma_addrs[1][0])
    # each interface talks to the address configured on the other one
    peer_ips = (local_ips[1], local_ips[0])

    # remove links left over by previously interrupted run.
    teardown_rdma()

    # use call here since modprobe may fail if the rdma_rxe
    # module is built-in
    subprocess.call(['modprobe', 'rdma_rxe'],
                    stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

    ip(f"link add {VETH_RDMA0} type veth peer name {VETH_RDMA1}")

    for dev in ifaces:
        ip(f"link set {dev} up")

    # accept_local: both addresses are in the same namespace, so the
    # source address is always local.
    # rp_filter: reverse path filters must be disabled so that the local
    # routes don't cause RPF failures.
    for setting in ("accept_local=1", "rp_filter=0"):
        for dev in ifaces:
            cmd(f"/usr/sbin/sysctl -q net.ipv4.conf.{dev}.{setting}")

    # add addresses
    for dev, addr in zip(ifaces, local_ips):
        ip(f"addr add {addr}/32 dev {dev}")

    # add routes
    for dev, peer in zip(ifaces, peer_ips):
        ip(f"route add {peer}/32 dev {dev}")

    # ARP will not resolve neighbor IPs on /32 routes without a subnet.
    # Avoid this by adding neighbors directly so RDMA CM can populate path
    # records with correct mac addrs without waiting for the ARP.
    macs = [get_iface_mac(dev) for dev in ifaces]
    peer_macs = (macs[1], macs[0])
    for dev, peer, peer_mac in zip(ifaces, peer_ips, peer_macs):
        ip(f"neigh add {peer} lladdr {peer_mac} dev {dev} nud permanent")

    for rxe_dev, dev in zip((RXE_DEV0, RXE_DEV1), ifaces):
        cmd(f'rdma link add {rxe_dev} type rxe netdev {dev}')

    time.sleep(1)  # allow RXE devices to initialise

    # Start a packet capture on each network
    if logdir is not None:
        sudo_user = os.environ.get('SUDO_USER')
        for dev in ifaces:
            pcap_path = f"{logdir}/rds-roce-{dev}.pcap"

            capture_cmd = ['/usr/sbin/tcpdump']
            if sudo_user:
                capture_cmd += ['-Z', sudo_user]
            capture_cmd += ['-i', dev, '-w', pcap_path]

            # pylint: disable-next=consider-using-with
            capture = subprocess.Popen(capture_cmd,
                                       stdout=subprocess.DEVNULL,
                                       stderr=subprocess.DEVNULL)
            tcpdump_procs.append(capture)

    # simulate packet loss, duplication and corruption
    for iface in ifaces:
        cmd(f"/usr/sbin/tc qdisc add dev {iface} root netem  \
             corrupt {PACKET_CORRUPTION} loss {PACKET_LOSS} duplicate  \
             {PACKET_DUPLICATE}")
460
def teardown_rdma():
    """
    Remove the rxe devices and veth link created by setup_rdma().

    Every command runs with fail=False, so this is also used to clear
    out links left over by a previously interrupted run.
    """
    cleanup_cmds = (
        f'rdma link del {RXE_DEV0}',
        f'rdma link del {RXE_DEV1}',
        f'ip link del {VETH_RDMA0}',
    )
    for cleanup in cleanup_cmds:
        cmd(cleanup, fail=False)
470
471
# Parse the command line arguments.  All of them are optional: a log
# output folder, the transports to test, a timeout, and the netem
# loss / corruption / duplication percentages.
parser = argparse.ArgumentParser(
    description="init script args",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("-d", "--logdir", action="store", default=None,
                    help="directory to store logs")
parser.add_argument("-T", "--transport", default="tcp",
                    help="Comma-separated list of transports to test: "
                         "tcp, rdma, or tcp,rdma.  Each matching test "
                         "is run once per transport.  "
                         "'rdma' requires CONFIG_RDS_RDMA and rdma_rxe.")
parser.add_argument('-t', '--timeout', type=int, default=0,
                    help="timeout to terminate hung test")
parser.add_argument('-l', '--loss', type=int, default=0,
                    help="Simulate tcp packet loss")
parser.add_argument('-c', '--corruption', type=int, default=0,
                    help="Simulate tcp packet corruption")
parser.add_argument('-u', '--duplicate', type=int, default=0,
                    help="Simulate tcp packet duplication")
args = parser.parse_args()

logdir = args.logdir
# netem takes these as percentages
PACKET_LOSS = f"{args.loss}%"
PACKET_CORRUPTION = f"{args.corruption}%"
PACKET_DUPLICATE = f"{args.duplicate}%"

# every requested transport must be either tcp or rdma
transports = [name.strip() for name in args.transport.split(',')]
for t in transports:
    if t in ('tcp', 'rdma'):
        continue
    raise SystemExit(f"test.py: unknown transport: {t!r}")
502
# stop_pcaps is registered ahead of any network setup so that tcpdump
# processes started by a setup that later fails are still stopped.
atexit.register(stop_pcaps)

# Bring up every requested transport before the first test runs, so the
# network plumbing is already in place.
transport_envs = {}
FLAGS = 0
if 'tcp' in transports:
    # The teardown is registered before the setup runs so that a setup
    # which errors out part-way is still cleaned up.
    atexit.register(teardown_tcp)
    setup_tcp()
    transport_envs['tcp'] = dict(
        addrs=tcp_addrs,
        netns=[NET0, NET1],
        flags=FLAGS | OP_FLAG_TCP,
    )

if 'rdma' in transports:
    atexit.register(teardown_rdma)
    setup_rdma()
    transport_envs['rdma'] = dict(
        addrs=rdma_addrs,
        netns=None,
        flags=FLAGS | OP_FLAG_RDMA,
    )
529
# Emit the TAP header and plan: one test per configured transport.
print("TAP version 13")
print(f"1..{len(transport_envs)}")

for transport, tenv in transport_envs.items():
    tap_idx += 1

    # add a timeout
    if args.timeout > 0:
        signal_handler_label = transport
        # Install the handler before arming the alarm so that SIGALRM can
        # never be delivered while the default action is still in place.
        signal.signal(signal.SIGALRM, signal_handler)
        signal.alarm(args.timeout)

    ret = snd_rcv_packets(tenv)

    # cancel timeout
    signal.alarm(0)

    if ret == 0:
        ksft_pr("Success")
        print(f"ok {tap_idx} rds selftest {transport}")
        nr_pass += 1
    else:
        print(f"not ok {tap_idx} rds selftest {transport}")
        nr_fail += 1

# The exit status is non-zero when any transport failed.
ksft_pr(f"Totals: pass:{nr_pass} fail:{nr_fail} skip:0")
sys.exit(1 if nr_fail else 0)
557