xref: /linux/tools/testing/selftests/net/lib/py/utils.py (revision 78c1930198fc63f2d4761848cbe148c5b2958b01)
# SPDX-License-Identifier: GPL-2.0
2
3import json as _json
4import os
5import re
6import select
7import socket
8import subprocess
9import time
10
11
class CmdInitFailure(Exception):
    """
    Command could not be brought up, e.g. it never reported readiness.

    The exception text is @msg followed by the repr() of the command
    object; the object itself is kept in the ``cmd`` attribute.
    """
    def __init__(self, msg, cmd_obj):
        self.cmd = cmd_obj
        super().__init__("\n".join((msg, repr(cmd_obj))))
17
18
class CmdExitFailure(Exception):
    """
    Command finished with a non-zero exit status.

    The exception text is @msg followed by the repr() of the command
    object; the object itself is kept in the ``cmd`` attribute.
    """
    def __init__(self, msg, cmd_obj):
        self.cmd = cmd_obj
        super().__init__("\n".join((msg, repr(cmd_obj))))
24
25
class CmdExitZeroFailure(CmdExitFailure):
    """
    Command exited with status 0 even though failure was expected.

    Message format and the ``cmd`` attribute come from CmdExitFailure.
    """
    pass
28
29
def fd_read_timeout(fd, timeout):
    """
    Read up to 1024 bytes from @fd, waiting at most @timeout seconds.

    A @timeout of None waits without limit. Returns b"" once the writer
    side of a pipe has been closed (EOF).

    Raises TimeoutError if @fd does not become readable in time.
    """
    readable, _unused_w, _unused_x = select.select([fd], [], [], timeout)
    if not readable:
        raise TimeoutError("Timeout waiting for fd read")
    return os.read(fd, 1024)
35
36
class cmd:
    """
    Execute a command on local or remote host.

    @shell defaults to false, and class will try to split @comm into a list
    if it's a string with spaces.

    Use bkg() instead to run a command in the background.

    Arguments:
      @comm:        command line, as a string or an argv list
      @shell:       passed to subprocess.Popen; when None a string @comm
                    containing spaces is split on whitespace
      @fail:        raise CmdExitFailure if the command exits non-zero
      @expect_fail: raise CmdExitZeroFailure if the command exits with 0
      @ns:          network namespace: either a netns name (run via
                    'ip netns exec') or an object with user_ns_path and
                    net_ns_path attributes (entered with nsenter)
      @background:  return right after starting; call process() later
      @host:        remote host object; the command is started via host.cmd()
      @timeout:     seconds to wait for a foreground command to finish
      @ksft_ready:  (local only) if not None, wait for the child to write to
                    the fd named in the KSFT_READY_FD environment variable;
                    without @ksft_wait there is no time limit on that wait
      @ksft_wait:   (local only) implies @ksft_ready and is its timeout in
                    seconds; the child also gets KSFT_WAIT_FD, which
                    process() writes to when the child should exit

    Once finished, decoded output is in .stdout / .stderr and the exit code
    in .ret.
    """
    def __init__(self, comm, shell=None, fail=True, expect_fail=False, ns=None,
                 background=False, host=None, timeout=5, ksft_ready=None,
                 ksft_wait=None):
        if ns:
            comm = self._ns_wrap(comm, ns)

        self.stdout = None
        self.stderr = None
        self.ret = None
        self.ksft_term_fd = None

        self.host = host
        self.comm = comm

        if host:
            self.proc = host.cmd(comm)
        else:
            # If user doesn't explicitly request shell try to avoid it.
            if shell is None and isinstance(comm, str) and ' ' in comm:
                comm = comm.split()

            # ksft_wait lets us wait for the background process to fully start,
            # we pass an FD to the child process, and wait for it to write back.
            # Similarly term_fd tells child it's time to exit.
            child_fds = []
            rfd = None
            env = os.environ.copy()
            if ksft_wait is not None:
                wait_fd, self.ksft_term_fd = os.pipe()
                child_fds.append(wait_fd)
                env["KSFT_WAIT_FD"] = str(wait_fd)
                ksft_ready = True  # ksft_wait implies ready
            if ksft_ready is not None:
                rfd, ready_fd = os.pipe()
                child_fds.append(ready_fd)
                env["KSFT_READY_FD"] = str(ready_fd)

            try:
                self.proc = subprocess.Popen(comm, shell=shell,
                                             stdout=subprocess.PIPE,
                                             stderr=subprocess.PIPE,
                                             pass_fds=child_fds, env=env)
            except BaseException:
                # Spawning failed (e.g. binary not found): drop our pipe ends.
                if rfd is not None:
                    os.close(rfd)
                self._close_term_fd()
                raise
            finally:
                # The child holds its own copies now; closing ours is what
                # lets a read on the ready pipe see EOF if the child dies.
                for fd in child_fds:
                    os.close(fd)

            if rfd is not None:
                self._wait_ready(rfd, ksft_wait)
        if not background:
            self.process(terminate=False, fail=fail, expect_fail=expect_fail,
                         timeout=timeout)

    @staticmethod
    def _ns_wrap(comm, ns):
        """
        Prefix @comm so it runs inside network namespace @ns.

        String commands get a string prefix, argv lists a list prefix.
        """
        if hasattr(ns, 'user_ns_path'):
            prefix = ['nsenter', f'--user={ns.user_ns_path}',
                      f'--net={ns.net_ns_path}', '--setuid=0', '--setgid=0',
                      '--']
        else:
            prefix = ['ip', 'netns', 'exec', f'{ns}']
        if isinstance(comm, str):
            return ' '.join(prefix) + ' ' + comm
        return prefix + list(comm)

    def _wait_ready(self, rfd, timeout):
        """
        Wait for the child's readiness message on @rfd, then close @rfd.

        Raises CmdInitFailure if the child closed its end without writing,
        or re-raises TimeoutError if nothing arrived within @timeout. In both
        cases the child is stopped and reaped first.
        """
        try:
            msg = fd_read_timeout(rfd, timeout)
        except TimeoutError:
            self._abort_start()
            raise
        finally:
            os.close(rfd)
        if not msg:
            self._abort_start()
            raise CmdInitFailure("Did not receive ready message", self)

    def _abort_start(self):
        """Stop (if still running) and reap a child that failed to start."""
        try:
            self._process_terminate(terminate=self.proc.poll() is None,
                                    timeout=1)
        finally:
            self._close_term_fd()

    def _close_term_fd(self):
        """Close the parent's end of the KSFT_WAIT_FD pipe, if any."""
        if self.ksft_term_fd is not None:
            os.close(self.ksft_term_fd)
            self.ksft_term_fd = None

    def _signal_term(self):
        """Tell a ksft_wait child to exit by writing to its KSFT_WAIT_FD."""
        try:
            os.write(self.ksft_term_fd, b"1")
        except BrokenPipeError:
            # The child already exited and closed its end; nothing to tell.
            pass
        finally:
            self._close_term_fd()

    def _record_output(self, stdout, stderr):
        """Store decoded output and the exit code on the object."""
        self.stdout = stdout.decode("utf-8")
        self.stderr = stderr.decode("utf-8")
        self.proc.stdout.close()
        self.proc.stderr.close()
        self.ret = self.proc.returncode

    def _process_terminate(self, terminate, timeout):
        """
        Optionally terminate the child, then collect its output and status.

        Returns the raw (stdout, stderr) bytes. If the child is still running
        after @timeout seconds it is killed, its output recorded, and
        subprocess.TimeoutExpired is re-raised.
        """
        if terminate:
            self.proc.terminate()
        try:
            stdout, stderr = self.proc.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            # communicate() leaves the child running on timeout; kill it so
            # it doesn't outlive the test, then reap it (retrying
            # communicate() after a timeout loses no output).
            self.proc.kill()
            self._record_output(*self.proc.communicate())
            raise
        self._record_output(stdout, stderr)
        return stdout, stderr

    def process(self, terminate=True, fail=None, expect_fail=False, timeout=5):
        """
        Wait for the command to finish and check its exit status.

        @terminate:   terminate the command first instead of waiting for it
        @fail:        raise CmdExitFailure on non-zero exit; defaults to
                      ``not terminate``
        @expect_fail: raise CmdExitZeroFailure if the command exits with 0
        @timeout:     seconds to wait; on expiry the command is killed and
                      subprocess.TimeoutExpired is raised
        """
        if fail is None:
            fail = not terminate

        if self.ksft_term_fd is not None:
            self._signal_term()

        self._process_terminate(terminate=terminate, timeout=timeout)

        # Fail on unexpected test failure if fail.
        # Fail on unexpected test success if expect_fail.
        # Fail on negative returncode if either:
        # Set by subprocess on crash or signal, this is never expected failure.
        ret = self.proc.returncode
        if (ret != 0 and fail) or (ret < 0 and expect_fail):
            raise CmdExitFailure("Command failed", self)
        elif ret == 0 and expect_fail:
            raise CmdExitZeroFailure("Command succeeded (expected fail)", self)

    def __repr__(self):
        def str_fmt(name, s):
            name += ': '
            return (name + s.strip().replace('\n', '\n' + ' ' * len(name)))

        ret = "CMD"
        if self.host:
            ret += "[remote]"
        if self.ret is None:
            ret += f" (unterminated): {self.comm}\n"
        elif self.ret == 0:
            ret += f" (success): {self.comm}\n"
        else:
            ret += f": {self.comm}\n"
            ret += f"  EXIT: {self.ret}\n"
        if self.stdout:
            ret += str_fmt("  STDOUT", self.stdout) + "\n"
        if self.stderr:
            ret += str_fmt("  STDERR", self.stderr) + "\n"
        return ret.strip()
159
160
class bkg(cmd):
    """
    Run a command in the background, for use as a context manager.

    The command starts on construction and is collected via process()
    when the with-block is left.

    Example usage:

    Run a command on remote host, and wait for it to finish.
    This is usually paired with wait_port_listen() to make sure
    the command has initialized:

        with bkg("socat ...", exit_wait=True, host=cfg.remote) as nc:
            ...

    Run a command and expect it to let us know that it's ready
    by writing to a special file descriptor passed via KSFT_READY_FD.
    When we leave the context manager the command is asked to exit
    through the KSFT_WAIT_FD descriptor:

        with bkg("my_binary", ksft_wait=5):

    Without exit_wait or ksft_wait the command is terminated on exit.
    """
    def __init__(self, comm, shell=None, fail=None, expect_fail=None,
                 ns=None, host=None, exit_wait=False, ksft_ready=None,
                 ksft_wait=None):
        super().__init__(comm, background=True,
                         shell=shell, fail=fail, expect_fail=expect_fail,
                         ns=ns, host=host, ksft_ready=ksft_ready,
                         ksft_wait=ksft_wait)
        # Terminate ourselves only when the command neither exits on its
        # own (exit_wait) nor gets told to exit via KSFT_WAIT_FD (ksft_wait).
        self.terminate = not (exit_wait or ksft_wait)
        self._exit_wait = exit_wait
        self.check_fail = fail
        self.expect_fail = expect_fail

        if self.terminate and shell:
            print("# Warning: combining shell and terminate is risky!")
            print("#          SIGTERM may not reach the child on zsh/ksh!")

    def __enter__(self):
        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        stop = self.terminate
        still_running = self.proc.poll() is None
        # If the with-body raised while we were going to wait for a natural
        # exit, terminate instead. Skip this when the command already exited:
        # forcing termination silences failures with fail=None.
        if still_running and self._exit_wait and ex_type is not None:
            stop = True
        return self.process(terminate=stop, fail=self.check_fail,
                            expect_fail=self.expect_fail)
206                            expect_fail=self.expect_fail)
207
208
# Pending defer() objects, in the order they were registered.
GLOBAL_DEFER_QUEUE: list = []
# defer() refuses to register anything while this is False; presumably the
# test runner arms it around each test case -- TODO confirm with the caller.
GLOBAL_DEFER_ARMED: bool = False
211
212
class defer:
    """
    Queue @func(*args, **kwargs) to be called later (e.g. for cleanup).

    The object joins GLOBAL_DEFER_QUEUE on creation. It can be run early
    with exec(), dropped with cancel(), or used as a context manager, in
    which case it runs when the with-block is left.
    """
    def __init__(self, func, *args, **kwargs):
        if not callable(func):
            raise Exception("defer created with un-callable object, did you call the function instead of passing its name?")

        self.func = func
        self.args, self.kwargs = args, kwargs

        if not GLOBAL_DEFER_ARMED:
            raise Exception("defer queue not armed, did you use defer() outside of a test case?")
        # Remember which queue we joined so cancel() removes us from it.
        self._queue = GLOBAL_DEFER_QUEUE
        self._queue.append(self)

    def __enter__(self):
        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        # Returns None, so an exception from the with-body still propagates.
        self.exec()

    def exec_only(self):
        """Call the deferred function without touching the queue."""
        self.func(*self.args, **self.kwargs)

    def cancel(self):
        """Drop this entry from the queue; ValueError if it isn't queued."""
        self._queue.remove(self)

    def exec(self):
        """Dequeue this entry, then run it."""
        self.cancel()
        self.exec_only()
242
243
def tool(name, args, json=None, ns=None, host=None):
    """
    Run "@name @args" via cmd().

    When @json is set, a JSON output flag is added ('-json' for tc,
    '--json' for everything else) and the parsed stdout is returned;
    otherwise the finished cmd object is returned.
    """
    parts = [name]
    if json:
        parts.append('-json' if name == 'tc' else '--json')
    parts.append(args)
    cmd_obj = cmd(' '.join(parts), ns=ns, host=host)
    return _json.loads(cmd_obj.stdout) if json else cmd_obj
256
257
def bpftool(args, json=None, ns=None, host=None):
    """Run bpftool with @args; parsed JSON if @json, else the cmd object."""
    return tool(name='bpftool', args=args, json=json, ns=ns, host=host)
260
261
def ip(args, json=None, ns=None, host=None):
    """
    Run iproute2 'ip' with @args.

    A namespace is selected with ip's own '-netns' option rather than by
    wrapping the command.
    """
    netns_opt = f'-netns {ns} ' if ns else ''
    return tool('ip', netns_opt + args, json=json, host=host)
266
267
def tc(args, json=None, ns=None, host=None):
    """
    Run 'tc' with @args.

    A namespace is selected with tc's '-netns' option rather than by
    wrapping the command.
    """
    netns_opt = f'-netns {ns} ' if ns else ''
    return tool('tc', netns_opt + args, json=json, host=host)
273
274
def ethtool(args, json=None, ns=None, host=None):
    """Run ethtool with @args; parsed JSON if @json, else the cmd object."""
    return tool(name='ethtool', args=args, json=json, ns=ns, host=host)
277
278
def bpftrace(expr, json=None, ns=None, host=None, timeout=None):
    """
    Run bpftrace and return map data (if json=True), else the cmd object.

    The output of bpftrace is inconvenient, so the helper converts
    to a dict indexed by map name with the leading '@' stripped, e.g.:
     {
       "":     { ... },   # map "@"
       "map2": { ... },   # map "@map2"
     }

    If @timeout is given, an interval probe makes the script exit after
    that many seconds and the command is allowed 20 extra seconds to finish.
    """
    cmd_arr = ['bpftrace']
    # Throw in --quiet if json, otherwise the output has two objects
    if json:
        cmd_arr += ['-f', 'json', '-q']
    if timeout:
        expr += ' interval:s:' + str(timeout) + ' { exit(); }'
        timeout += 20
    cmd_arr += ['-e', expr]
    cmd_obj = cmd(cmd_arr, ns=ns, host=host, shell=False, timeout=timeout)
    if not json:
        return cmd_obj

    # bpftrace prints one JSON object per line; keep only the map dumps.
    ret = {}
    for line in cmd_obj.stdout.split('\n'):
        if not line.strip():
            continue
        one = _json.loads(line)
        if one.get('type') != 'map':
            continue
        for k, v in one["data"].items():
            if k.startswith('@'):
                k = k.lstrip('@')
            ret[k] = v
    return ret
313
314
def rand_port(stype=socket.SOCK_STREAM):
    """
    Get a random unprivileged port (single-port form of rand_ports()).
    """
    (port,) = rand_ports(1, stype)
    return port
320
321
def rand_ports(count, stype=socket.SOCK_STREAM):
    """
    Get a unique set of random unprivileged ports.

    The kernel picks each port by binding an IPv6 socket to port 0. All
    sockets stay bound until every port is chosen, so no port repeats.
    They are closed before returning: the ports are free, not reserved.
    """
    held = []
    try:
        for _ in range(count):
            sock = socket.socket(socket.AF_INET6, stype)
            held.append(sock)
            sock.bind(("", 0))
        return [s.getsockname()[1] for s in held]
    finally:
        for s in held:
            s.close()
340
341
def wait_port_listen(port, proto="tcp", ns=None, host=None, sleep=0.005, deadline=5):
    """
    Poll /proc/net/@proto* until a socket on @port shows up.

    For TCP the row must also show state 0A (listening). Raises Exception
    if that doesn't happen within @deadline seconds.
    """
    end = time.monotonic() + deadline

    # /proc/net/* prints ports as 4-digit upper-case hex; for tcp we
    # additionally require the socket state column to read 0A.
    regex = f":{port:04X} .* " + ("0A" if proto == "tcp" else "")
    matcher = re.compile(regex)

    while True:
        listing = cmd(f'cat /proc/net/{proto}*', ns=ns, host=host, shell=True).stdout
        if any(matcher.search(row) for row in listing.split("\n")):
            return
        if time.monotonic() > end:
            raise Exception("Waiting for port listen timed out")
        time.sleep(sleep)
358
359
def wait_file(fname, test_fn, sleep=0.005, deadline=5, encoding='utf-8'):
    """
    Block until @test_fn(contents of local file @fname) returns true.

    The file is opened once and re-read from the start every @sleep
    seconds. Raises TimeoutError(msg, fname) after @deadline seconds.
    """
    end = time.monotonic() + deadline

    with open(fname, "r", encoding=encoding) as fp:
        while not test_fn(fp.read()):
            fp.seek(0)
            if time.monotonic() > end:
                raise TimeoutError("Wait for file contents failed", fname)
            time.sleep(sleep)
376