xref: /linux/tools/testing/selftests/net/rds/test.py (revision d603517771d8e08a2d8fc9e1f7682ce393d3973a)
1#! /usr/bin/env python3
2# SPDX-License-Identifier: GPL-2.0
3"""
4This module provides functional testing for the net/rds component.
5"""
6
7import argparse
8import atexit
9import ctypes
10import errno
11import hashlib
12import os
13import select
14import re
15import signal
16import socket
17import subprocess
18import sys
19import time
20
21# Allow utils module to be imported from different directory
22this_dir = os.path.dirname(os.path.realpath(__file__))
23sys.path.append(os.path.join(this_dir, "../"))
24# pylint: disable-next=wrong-import-position,import-error,no-name-in-module
25from lib.py.utils import ip, cmd # noqa: E402
26# pylint: disable-next=wrong-import-position,import-error,no-name-in-module
27from lib.py.ksft import ksft_pr # noqa: E402
28
# setns(2) wrapper used by netns_socket().  use_errno=True makes
# ctypes.get_errno() report why a failed call failed; argtypes/restype
# let ctypes reject calls with the wrong argument types.
libc = ctypes.CDLL('libc.so.6', use_errno=True)
setns = libc.setns
setns.argtypes = (ctypes.c_int, ctypes.c_int)
setns.restype = ctypes.c_int

NET0 = 'net0'
NET1 = 'net1'

VETH0 = 'veth0'
VETH1 = 'veth1'

# tcpdump processes started by setup_tcp()/setup_rdma(); drained by
# stop_pcaps().
tcpdump_procs = []
tcp_addrs = [
    # we technically don't need different port numbers, but this will
    # help identify traffic in the network analyzer
    ('10.0.0.1', 10000),
    ('10.0.0.2', 20000),
]

# RDMA network configs
RXE_DEV0 = 'rxe0'
RXE_DEV1 = 'rxe1'

VETH_RDMA0 = 'veth_rdma0'
VETH_RDMA1 = 'veth_rdma1'

rdma_addrs = [
    ('10.0.0.3', 30000),
    ('10.0.0.4', 30000),
]

# Transport selection flags for the "flags" entry of the environment
# dict passed to snd_rcv_packets(); exactly one must be set.
OP_FLAG_TCP     = 0x1
OP_FLAG_RDMA    = 0x2

# Transport name reported by signal_handler() when a test times out.
signal_handler_label = ""

# TAP bookkeeping for the main test loop.
tap_idx = 0
nr_pass = 0
nr_fail = 0
67
def netns_socket(netns, *sock_args):
    """
    Create a socket inside a network namespace.

    We need this because otherwise RDS will detect that the two TCP
    sockets are on the same interface and use the loop transport instead
    of the TCP transport.  A forked child enters ``netns`` with setns(2),
    creates the socket there and passes its descriptor back to the parent
    over an AF_UNIX socketpair.

    :param netns: the name of the network namespace (under /var/run/netns)
    :param sock_args: socket family and type, as passed to socket.socket()
    :returns: the new socket object, owned by the caller
    :raises RuntimeError: if the child could not enter the namespace or
        could not create/send the socket
    """
    u0, u1 = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)

    child = os.fork()
    if child == 0:
        try:
            # change network namespace; a failed setns() must not fall
            # through to creating the socket in the current namespace
            with open(f'/var/run/netns/{netns}', encoding='utf-8') as f:
                if setns(f.fileno(), 0) != 0:
                    raise OSError(ctypes.get_errno(),
                                  f"setns({netns}) failed")
            # create socket in target namespace
            sock = socket.socket(*sock_args)

            # send resulting socket to parent
            socket.send_fds(u0, [], [sock.fileno()])

            os._exit(0)
        except BaseException:
            os._exit(1)

    # Drop the parent's copy of the child's end before waiting: otherwise
    # recv_fds() never sees EOF and blocks forever if the child exits
    # without sending a descriptor.
    u0.close()
    fds = []
    try:
        # receive socket from child
        _, fds, _, _ = socket.recv_fds(u1, 0, 1)
    finally:
        u1.close()
        # always reap the child, even if receiving failed
        _, status = os.waitpid(child, 0)

    if not os.WIFEXITED(status) or os.WEXITSTATUS(status) != 0 or not fds:
        for fd in fds:
            os.close(fd)
        raise RuntimeError(
            f"netns_socket child failed in netns {netns} (status={status})")

    # socket.fromfd() duplicates the descriptor, so close the received one.
    try:
        return socket.fromfd(fds[0], *sock_args)
    finally:
        os.close(fds[0])
106
def send_burst(socks, ip_addrs, snd_hashes, nr_sent, nr_total):
    """
    Send packets until the socket would block or ``nr_total`` is reached.

    Packet number ``i`` carries the hex SHA-256 of ``"packet i"``.  The
    sending socket alternates every packet while the destination follows
    a period-3 pattern, so every direction, including a socket sending to
    itself, is exercised.  Each payload that was sent successfully is
    folded into ``snd_hashes[(sender_fd, receiver_fd)]``.

    :param socks: the two sockets taking part in the test
    :param ip_addrs: bound (ip, port) address of each socket, same order
    :param snd_hashes: dict of running sha256 objects keyed by fd pair,
                       updated in place
    :param nr_sent: number of packets already sent
    :param nr_total: total number of packets to send
    :returns: the updated number of packets sent
    """
    # errors that just mean "stop this burst and try again later"
    transient = (errno.ENOBUFS, errno.ECONNRESET, errno.EPIPE)

    count = nr_sent
    while count < nr_total:
        digest_hex = hashlib.sha256(f'packet {count}'.encode('utf-8')).hexdigest()
        payload = digest_hex.encode('utf-8')

        # pseudo-random send/receive pattern
        src = socks[count % 2]
        dst_slot = 1 - (count % 3) % 2
        dst = socks[dst_slot]

        try:
            src.sendto(payload, ip_addrs[dst_slot])
        except BlockingIOError:
            break
        except OSError as err:
            if err.errno not in transient:
                raise
            break

        running = snd_hashes.setdefault((src.fileno(), dst.fileno()),
                                        hashlib.sha256())
        running.update(f'<{payload}>'.encode('utf-8'))
        count += 1
    return count
131
def recv_burst(epoll, socks, ip_addrs, rcv_hashes, nr_rcv):
    """
    Drain every socket epoll reports readable and count what arrived.

    Waits in ``epoll.poll()`` for events, then reads each socket flagged
    EPOLLRDNORM until it would block.  Every payload is folded into
    ``rcv_hashes[(sender_fd, receiver_fd)]``.  The sender is found by
    looking the packet's source address up in ``ip_addrs``.

    :param epoll: epoll object with ``socks`` registered
    :param socks: the sockets taking part in the test
    :param ip_addrs: bound (ip, port) address of each socket, same order
    :param rcv_hashes: dict of running sha256 objects keyed by fd pair,
                       updated in place
    :param nr_rcv: number of packets received so far
    :returns: the updated number of packets received
    """
    def _drain(sock):
        # Yield (payload, source address) pairs until the socket would block.
        while True:
            try:
                yield sock.recvfrom(1024)
            except BlockingIOError:
                return

    received = nr_rcv
    for ready_fd, events in epoll.poll():
        if (events & select.EPOLLRDNORM) == 0:
            continue
        reader = next(s for s in socks if s.fileno() == ready_fd)
        for payload, src_addr in _drain(reader):
            writer = socks[ip_addrs.index(src_addr)]
            running = rcv_hashes.setdefault((writer.fileno(), reader.fileno()),
                                            hashlib.sha256())
            running.update(f'<{payload}>'.encode('utf-8'))
            received += 1
    return received
149
def check_info(socks):
    """
    Query every RDS info page on each socket and log a success/error tally.

    The Python socket module rejects getsockopt() buffer lengths above
    1024, while RDS wants multiple pages, so requests for larger info pages
    are expected to fail (e.g. with ENOSPC).  Failures of any kind are only
    counted, never raised.

    :param socks: list of sockets to check
    """
    # RDS info option range; the Python socket module doesn't know these
    first_opt, last_opt = 10000, 10017

    ok_count = 0
    err_count = 0
    for sock in socks:
        for opt in range(first_opt, last_opt + 1):
            try:
                sock.getsockopt(socket.SOL_RDS, opt, 1024)
            except OSError:
                # ENOSPC and every other error are simply tallied
                err_count += 1
            else:
                ok_count += 1

    ksft_pr(f"getsockopt(): {ok_count}/{err_count}")
179
def verify_hashes(snd_hashes, rcv_hashes):
    """
    Check that every (sender, receiver) stream arrived intact.

    Only pairs present in ``snd_hashes`` are examined; a pair that appears
    only on the receive side is not reported.

    :param snd_hashes: dict mapping (sender fd, receiver fd) to the running
                       sha256 of the payloads sent on that pair
    :param rcv_hashes: the same, accumulated on the receiving side
    :returns: 0 if every sent stream has a matching received digest,
              1 at the first missing or mismatching pair
    """
    for pair, sent in snd_hashes.items():
        received = rcv_hashes.get(pair)
        if received is None:
            ksft_pr("FAIL: No data received")
            return 1

        expected, actual = sent.hexdigest(), received.hexdigest()
        if expected != actual:
            ksft_pr("FAIL: Send/recv mismatch")
            ksft_pr("hash expected:", expected)
            ksft_pr("hash received:", actual)
            return 1

        ksft_pr(f"{pair[0]}/{pair[1]}: ok")
    return 0
194
def snd_rcv_packets(env):
    """
    Exchange packets between two RDS sockets and verify the payloads.

    Two RDS sockets are created for the selected transport: inside the two
    network namespaces for TCP, in the current namespace for RDMA.  They
    are bound to ``env["addrs"]`` and made non-blocking.  Packets are then
    sent in non-blocking bursts, each burst followed by draining the
    receivers until everything sent so far has arrived.  For TCP the
    rds_tcp_rcvbuf/rds_tcp_sndbuf sysctls are rewritten after every round.
    Finally the RDS info pages are queried and the per-(sender, receiver)
    hashes of sent and received data are compared.

    The sockets and the epoll object are closed on every exit path,
    including when an exception escapes.

    :param env: transport-environment dict for setup_tcp() / setup_rdma().
                "addrs": list of (ip, port) tuples matching the sockets
                "netns": list of netns names for TCP or None for RDMA
                "flags": OP_FLAG_TCP or OP_FLAG_RDMA, selects sockets
    :returns: 0 if the received data matches what was sent, 1 otherwise
    :raises RuntimeError: if "flags" selects both transports or neither
    """

    addrs = env["addrs"]
    netns_list = env["netns"]
    flags = env.get("flags", 0)

    if (flags & OP_FLAG_TCP) and (flags & OP_FLAG_RDMA):
        raise RuntimeError(f"Invalid transport flag sets multiple transports: {flags}")
    if not flags & (OP_FLAG_TCP | OP_FLAG_RDMA):
        raise RuntimeError(f"Invalid transport flag sets no transports: {flags}")

    num_packets = 50000
    sockets = []
    ep = None
    try:
        # Append each socket as soon as it exists so that a failure while
        # creating the second one still lets the finally clause close the
        # first.
        for idx in range(2):
            if flags & OP_FLAG_TCP:
                sock = netns_socket(netns_list[idx], socket.AF_RDS,
                                    socket.SOCK_SEQPACKET)
            else:
                sock = socket.socket(socket.AF_RDS, socket.SOCK_SEQPACKET)
            sockets.append(sock)

        for s, addr in zip(sockets, addrs):
            s.bind(addr)
            s.setblocking(0)

        send_hashes = {}
        recv_hashes = {}

        ep = select.epoll()
        for s in sockets:
            ep.register(s, select.EPOLLRDNORM)

        nr_send = 0
        nr_recv = 0

        while nr_send < num_packets:

            # Send as much as we can without blocking
            ksft_pr("sending...", nr_send, nr_recv)
            nr_send = send_burst(sockets, addrs, send_hashes, nr_send,
                                 num_packets)

            # Receive until everything sent so far has arrived
            ksft_pr("receiving...", nr_send, nr_recv)
            while nr_recv < nr_send:
                nr_recv = recv_burst(ep, sockets, addrs, recv_hashes, nr_recv)

            # exercise net/rds/tcp.c:rds_tcp_sysctl_reset()
            if netns_list:
                for net in netns_list:
                    ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_rcvbuf=10000")
                    ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_sndbuf=10000")

        ksft_pr("done", nr_send, nr_recv)

        check_info(sockets)

        # We're done sending and receiving stuff, now let's check if what
        # we received is what we sent.
        return verify_hashes(send_hashes, recv_hashes)
    finally:
        if ep is not None:
            ep.close()
        for s in sockets:
            s.close()
271
def stop_pcaps():
    """
    Stop every tcpdump process recorded in tcpdump_procs.

    Entries are popped as they are handled, so the list ends up empty and
    calling this again is harmless.  The list is empty when no logdir was
    given, in which case nothing is printed.  Each capture is asked to
    terminate and given five seconds before it is killed.
    """
    if tcpdump_procs:
        ksft_pr("Stopping network packet captures")

    while tcpdump_procs:
        capture = tcpdump_procs.pop()
        capture.terminate()
        try:
            capture.wait(timeout=5)
        except subprocess.TimeoutExpired:
            # did not exit on terminate(); force it and reap it
            capture.kill()
            capture.wait()
292
def signal_handler(_sig, _frame):
    """
    SIGALRM handler installed by the main loop when --timeout is given.

    Reports the transport currently under test as a failed TAP test point
    and exits with status 1 via sys.exit(), so atexit cleanups still run.
    """
    label = signal_handler_label
    ksft_pr(f"Test timed out: {label}")
    print(f"not ok {tap_idx} rds selftest {label}")
    sys.exit(1)
300
def setup_tcp():
    """
    Configure the tcp test network.

    Creates namespaces NET0 and NET1 joined by the VETH0/VETH1 veth pair,
    assigns the tcp_addrs addresses with host routes to each other and
    checks connectivity with a single ping.  When logdir is set, a tcpdump
    capture is started in each namespace.  Finally a netem qdisc injecting
    the configured corruption, loss and duplication is added on both ends.
    """

    # clean up any leftovers from a previously interrupted run
    teardown_tcp()

    ip(f"netns add {NET0}")
    ip(f"netns add {NET1}")
    # Name both ends explicitly rather than relying on the kernel to pick
    # veth0/veth1: if those names were already in use the kernel would pick
    # other ones and the moves below would grab the wrong interfaces.  With
    # explicit names a clash makes this command fail instead.
    ip(f"link add {VETH0} type veth peer name {VETH1}")

    # Move TCP interfaces into separate namespaces so they can no longer be
    # bound directly; this prevents rds from switching over from the tcp
    # transport to the loop transport.
    ip(f"link set {VETH0} netns {NET0} up")
    ip(f"link set {VETH1} netns {NET1} up")

    # add addresses
    ip(f"-n {NET0} addr add {tcp_addrs[0][0]}/32 dev {VETH0}")
    ip(f"-n {NET1} addr add {tcp_addrs[1][0]}/32 dev {VETH1}")

    # add routes
    ip(f"-n {NET0} route add {tcp_addrs[1][0]}/32 dev {VETH0}")
    ip(f"-n {NET1} route add {tcp_addrs[0][0]}/32 dev {VETH1}")

    # sanity check that our two interfaces/addresses are correctly set up
    # and communicating by doing a single ping
    ip(f"netns exec {NET0} ping -c 1 {tcp_addrs[1][0]}")

    # Start a packet capture on each network
    if logdir is not None:
        for netn in [NET0, NET1]:
            pcap = os.path.join(logdir, f'rds-{netn}.pcap')

            tcpdump_cmd = ['ip', 'netns', 'exec', netn, '/usr/sbin/tcpdump']
            sudo_user = os.environ.get('SUDO_USER')
            if sudo_user:
                # drop privileges so the pcap is owned by the invoking user
                tcpdump_cmd.extend(['-Z', sudo_user])
            tcpdump_cmd.extend(['-i', 'any', '-w', pcap])

            # pylint: disable-next=consider-using-with
            p = subprocess.Popen(tcpdump_cmd,
                                 stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            tcpdump_procs.append(p)

    # simulate packet loss, duplication and corruption
    for netn, iface in [(NET0, VETH0), (NET1, VETH1)]:
        ip(f"netns exec {netn} /usr/sbin/tc qdisc add dev {iface} root netem  \
             corrupt {PACKET_CORRUPTION} loss {PACKET_LOSS} duplicate  \
             {PACKET_DUPLICATE}")
352
def teardown_tcp():
    """
    Remove the tcp network created by setup_tcp().

    Deleting the two namespaces also removes the veth pair, addresses,
    routes and netem qdiscs that live inside them.  Failures are ignored
    (fail=False), so this is safe to run before setup, after a partial or
    complete setup, or more than once.
    """
    for netns in (NET0, NET1):
        cmd(f"ip netns del {netns}", fail=False)
363
def get_iface_mac(iface):
    """
    Look up the link/ether (MAC) address of a local network interface.

    :param iface: interface name in the current network namespace
    :returns: the address as printed by ``ip link show``
    :raises RuntimeError: if the output contains no link/ether address
    :raises subprocess.CalledProcessError: if ``ip link show`` fails
    """
    link_info = subprocess.check_output(['ip', 'link', 'show', iface], text=True)
    found = re.search(r'link/ether\s+([0-9a-f:]+)', link_info)
    if found is not None:
        return found.group(1)
    raise RuntimeError(f"Cannot determine MAC address of {iface}")
371
def setup_rdma():
    """
    Configure the rdma (rxe) test network.

    Both ends of the VETH_RDMA0/VETH_RDMA1 veth pair stay in the current
    network namespace.  This function:

    - loads rdma_rxe (best effort) and brings the pair up;
    - sets accept_local=1 and rp_filter=0 on both interfaces;
    - assigns the rdma_addrs addresses with host routes and permanent
      neighbor entries;
    - creates the RXE_DEV0/RXE_DEV1 rxe devices on top of the pair;
    - when logdir is set, starts a tcpdump capture per interface;
    - adds a netem qdisc injecting the configured corruption, loss and
      duplication.
    """

    # remove links left over by previously interrupted run.
    teardown_rdma()

    # Best effort: modprobe fails if the rdma_rxe module is built-in, and
    # the modprobe binary itself may be missing on minimal systems.  If the
    # driver really is unavailable, the "rdma link add" step below fails.
    try:
        subprocess.call(['modprobe', 'rdma_rxe'],
                        stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except FileNotFoundError:
        pass

    ip(f"link add {VETH_RDMA0} type veth peer name {VETH_RDMA1}")

    ip(f"link set {VETH_RDMA0} up")
    ip(f"link set {VETH_RDMA1} up")

    # Since both addresses are in the same namespace, the source address
    # is always local, so enable accept_local
    cmd(f"/usr/sbin/sysctl -q net.ipv4.conf.{VETH_RDMA0}.accept_local=1")
    cmd(f"/usr/sbin/sysctl -q net.ipv4.conf.{VETH_RDMA1}.accept_local=1")

    # Reverse path filters must be disabled so that the local routes don't
    # cause RPF failures.
    # NOTE(review): ip-sysctl docs say the effective rp_filter is the max of
    # conf/all and conf/<iface>; if conf.all.rp_filter is non-zero on the
    # host, these per-interface settings may not be enough -- confirm.
    cmd(f"/usr/sbin/sysctl -q net.ipv4.conf.{VETH_RDMA0}.rp_filter=0")
    cmd(f"/usr/sbin/sysctl -q net.ipv4.conf.{VETH_RDMA1}.rp_filter=0")

    # add addresses
    ip(f"addr add {rdma_addrs[0][0]}/32 dev {VETH_RDMA0}")
    ip(f"addr add {rdma_addrs[1][0]}/32 dev {VETH_RDMA1}")

    # add routes
    ip(f"route add {rdma_addrs[1][0]}/32 dev {VETH_RDMA0}")
    ip(f"route add {rdma_addrs[0][0]}/32 dev {VETH_RDMA1}")

    # ARP will not resolve neighbor IPs on /32 routes without a subnet.
    # Avoid this by adding neighbors directly so RDMA CM can populate path
    # records with correct mac addrs without waiting for the ARP.
    mac0 = get_iface_mac(VETH_RDMA0)
    mac1 = get_iface_mac(VETH_RDMA1)
    ip(f"neigh add {rdma_addrs[1][0]} lladdr {mac1} dev {VETH_RDMA0} nud permanent")
    ip(f"neigh add {rdma_addrs[0][0]} lladdr {mac0} dev {VETH_RDMA1} nud permanent")

    cmd(f'rdma link add {RXE_DEV0} type rxe netdev {VETH_RDMA0}')
    cmd(f'rdma link add {RXE_DEV1} type rxe netdev {VETH_RDMA1}')

    time.sleep(1)  # allow RXE devices to initialise

    # Start a packet capture on each network
    if logdir is not None:
        for iface in [VETH_RDMA0, VETH_RDMA1]:
            pcap = os.path.join(logdir, f'rds-roce-{iface}.pcap')

            tcpdump_cmd = ['/usr/sbin/tcpdump']
            sudo_user = os.environ.get('SUDO_USER')
            if sudo_user:
                # drop privileges so the pcap is owned by the invoking user
                tcpdump_cmd.extend(['-Z', sudo_user])
            tcpdump_cmd.extend(['-i', iface, '-w', pcap])

            # pylint: disable-next=consider-using-with
            p = subprocess.Popen(tcpdump_cmd,
                                 stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            tcpdump_procs.append(p)

    # simulate packet loss, duplication and corruption
    for iface in [VETH_RDMA0, VETH_RDMA1]:
        cmd(f"/usr/sbin/tc qdisc add dev {iface} root netem  \
             corrupt {PACKET_CORRUPTION} loss {PACKET_LOSS} duplicate  \
             {PACKET_DUPLICATE}")
442
def teardown_rdma():
    """
    Remove the rdma network configured by setup_rdma().

    Every step passes fail=False, so this is safe to run when nothing (or
    only part of the setup) exists, e.g. to clear leftovers of an
    interrupted run.
    """
    for rxe_dev in (RXE_DEV0, RXE_DEV1):
        cmd(f'rdma link del {rxe_dev}', fail=False)
    # only one end is deleted explicitly; the veth peer goes with it
    cmd(f'ip link del {VETH_RDMA0}', fail=False)
452
453
# Parse command line arguments: an optional timeout, an optional log output
# folder, the transports to test and the simulated network impairments.
parser = argparse.ArgumentParser(description="init script args",
                  formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-d", "--logdir", action="store",
                    help="directory to store logs", default=None)
parser.add_argument("-T", "--transport", default="tcp",
                    help="Comma-separated list of transports to test: "
                         "tcp, rdma, or tcp,rdma.  Each matching test "
                         "is run once per transport.  "
                         "'rdma' requires CONFIG_RDS_RDMA and rdma_rxe.")
parser.add_argument('-t', '--timeout', help="timeout to terminate hung test",
                    type=int, default=0)
# The netem impairments are applied by both setup_tcp() and setup_rdma(),
# so the help text does not single out tcp.
parser.add_argument('-l', '--loss', help="Simulate packet loss (percent)",
                    type=int, default=0)
parser.add_argument('-c', '--corruption', help="Simulate packet corruption (percent)",
                    type=int, default=0)
parser.add_argument('-u', '--duplicate', help="Simulate packet duplication (percent)",
                    type=int, default=0)
args = parser.parse_args()
logdir = args.logdir
PACKET_LOSS = str(args.loss) + '%'
PACKET_CORRUPTION = str(args.corruption) + '%'
PACKET_DUPLICATE = str(args.duplicate) + '%'

# check transport is either tcp or rdma
transports = [t.strip() for t in args.transport.split(',')]
for t in transports:
    if t not in ('tcp', 'rdma'):
        raise SystemExit(f"test.py: unknown transport: {t!r}")

# Register stop_pcaps before any network setups so that any partially setup
# tcpdumps are still cleaned up on error
atexit.register(stop_pcaps)

# Set up all requested transports upfront so network plumbing is
# ready before any test runs.
transport_envs = {}
FLAGS = 0
if 'tcp' in transports:
    # Register cleanups before setups to handle partial setups that error'd out
    atexit.register(teardown_tcp)
    setup_tcp()
    transport_envs['tcp'] = {
        'addrs': tcp_addrs,
        'netns': [NET0, NET1],
        'flags': FLAGS | OP_FLAG_TCP,
    }

if 'rdma' in transports:
    atexit.register(teardown_rdma)
    setup_rdma()
    transport_envs['rdma'] = {
        'addrs': rdma_addrs,
        'netns': None,
        'flags': FLAGS | OP_FLAG_RDMA,
    }

print("TAP version 13")
print(f"1..{len(transport_envs)}")

for transport, tenv in transport_envs.items():
    tap_idx += 1

    # add a timeout
    if args.timeout > 0:
        signal_handler_label = transport
        # Install the handler before arming the alarm so that SIGALRM can
        # never be delivered with its default action (terminate) in effect.
        signal.signal(signal.SIGALRM, signal_handler)
        signal.alarm(args.timeout)

    ret = snd_rcv_packets(tenv)

    # cancel timeout
    signal.alarm(0)

    if ret == 0:
        ksft_pr("Success")
        print(f"ok {tap_idx} rds selftest {transport}")
        nr_pass += 1
    else:
        print(f"not ok {tap_idx} rds selftest {transport}")
        nr_fail += 1

ksft_pr(f"Totals: pass:{nr_pass} fail:{nr_fail} skip:0")
sys.exit(1 if nr_fail else 0)
539