xref: /linux/tools/testing/selftests/net/lib/py/utils.py (revision 6a4c4656b0d2d4056a1f0c35442db4e8a5cf8021)
1# SPDX-License-Identifier: GPL-2.0
2
3import json as _json
4import os
5import re
6import select
7import socket
8import subprocess
9import time
10
11
class CmdInitFailure(Exception):
    """
    A command failed to start.

    Raised when a command launched with ksft_ready/ksft_wait does not report
    that it is ready. The offending cmd object is kept in .cmd and its repr()
    is appended to the exception message.
    """
    def __init__(self, msg, cmd_obj):
        self.cmd = cmd_obj
        super().__init__("\n".join((msg, repr(cmd_obj))))
17
18
class CmdExitFailure(Exception):
    """
    A command exited with a non-zero status.

    The cmd object is stored in .cmd and its repr() (command line, exit code
    and captured output) is appended to the exception message.
    """
    def __init__(self, msg, cmd_obj):
        self.cmd = cmd_obj
        super().__init__("\n".join((msg, repr(cmd_obj))))
24
25
class CmdExitZeroFailure(CmdExitFailure):
    """
    A command exited with status zero even though failure was expected
    (expect_fail was requested).
    """
28
29
def fd_read_timeout(fd, timeout):
    """
    Read up to 1024 bytes from @fd, waiting at most @timeout seconds.

    @timeout is handed straight to select.select(), so None waits forever.
    Returns whatever os.read() returns (b"" at end of file).
    Raises TimeoutError if @fd did not become readable in time.
    """
    readable = select.select([fd], [], [], timeout)[0]
    if not readable:
        raise TimeoutError("Timeout waiting for fd read")
    return os.read(fd, 1024)
35
36
class cmd:
    """
    Execute a command on local or remote host.

    @shell defaults to false, and class will try to split @comm into a list
    if it's a string with spaces.

    Use bkg() instead to run a command in the background.

    Once the command has been processed, @stdout / @stderr hold its output
    decoded as UTF-8 and @ret its exit code. @proc is the process object
    (subprocess.Popen locally, whatever host.cmd() returns for remote hosts).
    """
    def __init__(self, comm, shell=None, fail=True, expect_fail=False, ns=None,
                 background=False, host=None, timeout=5, ksft_ready=None,
                 ksft_wait=None):
        if ns:
            # Run inside the netns. Handle both the string form and the
            # argv-list form of @comm (string concat would fail on a list).
            if isinstance(comm, str):
                comm = f'ip netns exec {ns} ' + comm
            else:
                comm = ['ip', 'netns', 'exec', str(ns)] + list(comm)

        self.stdout = None
        self.stderr = None
        self.ret = None
        self.ksft_term_fd = None

        self.host = host
        self.comm = comm

        if host:
            self.proc = host.cmd(comm)
        else:
            # If user doesn't explicitly request shell try to avoid it.
            if shell is None and isinstance(comm, str) and ' ' in comm:
                comm = comm.split()

            # ksft_wait lets us wait for the background process to fully start,
            # we pass an FD to the child process, and wait for it to write back.
            # Similarly term_fd tells child it's time to exit.
            pass_fds = []
            env = os.environ.copy()
            if ksft_wait is not None:
                wait_fd, self.ksft_term_fd = os.pipe()
                pass_fds.append(wait_fd)
                env["KSFT_WAIT_FD"] = str(wait_fd)
                ksft_ready = True  # ksft_wait implies ready
            if ksft_ready is not None:
                rfd, ready_fd = os.pipe()
                pass_fds.append(ready_fd)
                env["KSFT_READY_FD"] = str(ready_fd)

            self.proc = subprocess.Popen(comm, shell=shell, stdout=subprocess.PIPE,
                                         stderr=subprocess.PIPE, pass_fds=pass_fds,
                                         env=env)
            # The child owns its pipe ends now; drop the parent's copies so
            # EOF is seen when the child exits.
            if ksft_wait is not None:
                os.close(wait_fd)
            if ksft_ready is not None:
                os.close(ready_fd)
                try:
                    msg = fd_read_timeout(rfd, ksft_wait)
                except TimeoutError:
                    # Never reported ready: don't leave the child running.
                    self._abort_start()
                    raise
                finally:
                    os.close(rfd)
                if not msg:
                    # EOF without a message: the child exited (or closed
                    # the fd) before signalling readiness.
                    self._abort_start()
                    raise CmdInitFailure("Did not receive ready message", self)
        if not background:
            self.process(terminate=False, fail=fail, expect_fail=expect_fail,
                         timeout=timeout)

    def _close_term_fd(self):
        """ Close the parent's end of the KSFT_WAIT_FD pipe, if open. """
        if self.ksft_term_fd is not None:
            os.close(self.ksft_term_fd)
            self.ksft_term_fd = None

    def _abort_start(self):
        """ Stop and reap a child that failed to report ready. """
        try:
            self._process_terminate(terminate=self.proc.poll() is None,
                                    timeout=1)
        finally:
            self._close_term_fd()

    def _process_terminate(self, terminate, timeout):
        """
        Optionally SIGTERM the process, then wait up to @timeout seconds for
        it to exit and record its output and exit code.

        Returns the raw (bytes) stdout and stderr. If the wait times out the
        child is killed and reaped before subprocess.TimeoutExpired is
        re-raised, because communicate() does not kill it on timeout.
        """
        if terminate:
            self.proc.terminate()
        try:
            stdout, stderr = self.proc.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            self.proc.kill()
            self.proc.wait()
            self.ret = self.proc.returncode
            raise
        self.stdout = stdout.decode("utf-8")
        self.stderr = stderr.decode("utf-8")
        self.proc.stdout.close()
        self.proc.stderr.close()
        self.ret = self.proc.returncode

        return stdout, stderr

    def process(self, terminate=True, fail=None, expect_fail=False, timeout=5):
        """
        Finish the command and check its exit code.

        @terminate sends SIGTERM first. @fail (default: not @terminate)
        raises CmdExitFailure on a non-zero exit. @expect_fail raises
        CmdExitZeroFailure on a zero exit, and CmdExitFailure if the process
        died from a signal.
        """
        if fail is None:
            fail = not terminate

        if self.ksft_term_fd is not None:
            # Tell a ksft_wait child it may exit now.
            try:
                os.write(self.ksft_term_fd, b"1")
            except BrokenPipeError:
                # Child already exited; its exit status is checked below.
                pass
            finally:
                self._close_term_fd()

        self._process_terminate(terminate=terminate, timeout=timeout)

        # Fail on unexpected test failure if fail.
        # Fail on unexpected test success if expect_fail.
        # A negative returncode means the process was killed by a signal
        # (crash or termination); that is never the expected failure.
        returncode = self.proc.returncode
        if (returncode != 0 and fail) or (returncode < 0 and expect_fail):
            raise CmdExitFailure("Command failed", self)
        elif returncode == 0 and expect_fail:
            raise CmdExitZeroFailure("Command succeeded (expected fail)", self)

    def __repr__(self):
        def str_fmt(name, s):
            name += ': '
            return (name + s.strip().replace('\n', '\n' + ' ' * len(name)))

        ret = "CMD"
        if self.host:
            ret += "[remote]"
        if self.ret is None:
            ret += f" (unterminated): {self.comm}\n"
        elif self.ret == 0:
            ret += f" (success): {self.comm}\n"
        else:
            ret += f": {self.comm}\n"
            ret += f"  EXIT: {self.ret}\n"
        if self.stdout:
            ret += str_fmt("  STDOUT", self.stdout) + "\n"
        if self.stderr:
            ret += str_fmt("  STDERR", self.stderr) + "\n"
        return ret.strip()
154
155
class bkg(cmd):
    """
    Run a command in the background.

    Example usage:

    Run a command on remote host, and wait for it to finish.
    This is usually paired with wait_port_listen() to make sure
    the command has initialized:

        with bkg("socat ...", exit_wait=True, host=cfg.remote) as nc:
            ...

    Run a command and expect it to let us know that it's ready
    by writing to a special file descriptor passed via KSFT_READY_FD.
    With ksft_wait the command is not sent SIGTERM on leaving the context
    manager; instead a byte is written to the pipe it received as
    KSFT_WAIT_FD and we wait for it to exit:

        with bkg("my_binary", ksft_wait=5):
    """
    def __init__(self, comm, shell=None, fail=None, expect_fail=None,
                 ns=None, host=None, exit_wait=False, ksft_ready=None,
                 ksft_wait=None):
        super().__init__(comm, background=True,
                         shell=shell, fail=fail, expect_fail=expect_fail,
                         ns=ns, host=host, ksft_ready=ksft_ready,
                         ksft_wait=ksft_wait)
        # SIGTERM on exit unless the command should finish by itself
        # (exit_wait) or is told to exit through KSFT_WAIT_FD (ksft_wait).
        self.terminate = not (exit_wait or ksft_wait)
        self._exit_wait = exit_wait
        self.check_fail = fail
        self.expect_fail = expect_fail

        if self.terminate and shell:
            print("# Warning: combining shell and terminate is risky!")
            print("#          SIGTERM may not reach the child on zsh/ksh!")

    def __enter__(self):
        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        # If the with-body raised while an exit_wait command is still
        # running, terminate it rather than wait. Skip this when it has
        # already exited, as forced termination would hide its failure
        # under fail=None.
        still_running = self.proc.poll() is None
        kill_it = self.terminate
        if still_running and not kill_it:
            kill_it = self._exit_wait and ex_type is not None
        return self.process(terminate=kill_it, fail=self.check_fail,
                            expect_fail=self.expect_fail)
202
203
# GLOBAL_DEFER_QUEUE: pending defer objects, appended on creation.
# GLOBAL_DEFER_ARMED: defer() raises unless this is true.
# NOTE(review): presumably armed/drained by the test runner; not visible here.
GLOBAL_DEFER_QUEUE, GLOBAL_DEFER_ARMED = [], False
206
207
class defer:
    """
    Register func(*args, **kwargs) to be run later.

    On creation the object appends itself to GLOBAL_DEFER_QUEUE. exec()
    removes it from the queue and runs it; cancel() only removes it;
    exec_only() only runs it. Used as a context manager it runs (via exec())
    when the with-block is left. Creation raises if GLOBAL_DEFER_ARMED is
    false or if func is not callable.
    """
    def __init__(self, func, *args, **kwargs):
        if not callable(func):
            raise Exception("defer created with un-callable object, did you call the function instead of passing its name?")
        if not GLOBAL_DEFER_ARMED:
            raise Exception("defer queue not armed, did you use defer() outside of a test case?")

        self.func, self.args, self.kwargs = func, args, kwargs
        self._queue = GLOBAL_DEFER_QUEUE
        self._queue.append(self)

    def __enter__(self):
        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        # exec() returns None, so exceptions from the with-body propagate.
        self.exec()

    def exec_only(self):
        func = self.func
        func(*self.args, **self.kwargs)

    def cancel(self):
        self._queue.remove(self)

    def exec(self):
        # Dequeue first so the callable is not run a second time later.
        self.cancel()
        self.exec_only()
237
238
def tool(name, args, json=None, ns=None, host=None):
    """
    Run command-line tool @name with the argument string @args via cmd().

    If @json is truthy the tool's JSON output flag is inserted before @args
    ('-json' for tc, '--json' for anything else) and the parsed stdout is
    returned; otherwise the cmd object itself is returned.
    """
    parts = [name]
    if json:
        parts.append('-json' if name == 'tc' else '--json')
    parts.append(args)
    cmd_obj = cmd(' '.join(parts), ns=ns, host=host)
    return _json.loads(cmd_obj.stdout) if json else cmd_obj
251
252
def bpftool(args, json=None, ns=None, host=None):
    """ Run bpftool with @args through tool() (same @json/@ns/@host meaning). """
    return tool('bpftool', args, json, ns, host)
255
256
def ip(args, json=None, ns=None, host=None):
    """
    Run the ip tool with @args through tool().

    @ns is passed via ip's own "-netns" option rather than an
    "ip netns exec" wrapper; @json and @host are forwarded to tool().
    """
    netns_opt = f'-netns {ns} ' if ns else ''
    return tool('ip', netns_opt + args, json=json, host=host)
261
262
def tc(args, json=None, ns=None, host=None):
    """
    Helper to call tc with standard set of optional args.

    @ns is passed via tc's "-netns" option; @json and @host are forwarded
    to tool().
    """
    netns_opt = f'-netns {ns} ' if ns else ''
    return tool('tc', netns_opt + args, json=json, host=host)
268
269
def ethtool(args, json=None, ns=None, host=None):
    """ Run ethtool with @args through tool() (same @json/@ns/@host meaning). """
    return tool('ethtool', args, json, ns, host)
272
273
def bpftrace(expr, json=None, ns=None, host=None, timeout=None):
    """
    Run bpftrace and return map data (if json=True).
    The output of bpftrace is inconvenient, so the helper converts
    it to a dict indexed by map name with the leading '@' removed, e.g.:
     {
       "":     { ... },    # map "@"
       "map2": { ... },    # map "@map2"
     }
    Without @json the cmd object is returned.

    @timeout, if set, appends an interval probe that makes bpftrace exit
    after that many seconds; the command itself is then given 20 seconds
    more than that to finish.
    """
    cmd_arr = ['bpftrace']
    # Throw in --quiet if json, otherwise the output has two objects
    if json:
        cmd_arr += ['-f', 'json', '-q']
    if timeout:
        expr += ' interval:s:' + str(timeout) + ' { exit(); }'
        timeout += 20
    cmd_arr += ['-e', expr]
    if ns:
        # Apply the netns wrapper to the argv list directly instead of
        # relying on cmd()'s ns= handling of the command.
        cmd_arr = ['ip', 'netns', 'exec', str(ns)] + cmd_arr
    cmd_obj = cmd(cmd_arr, host=host, shell=False, timeout=timeout)
    if not json:
        return cmd_obj

    # bpftrace prints one JSON object per line; keep only map objects.
    ret = {}
    for line in cmd_obj.stdout.split('\n'):
        if not line.strip():
            continue
        one = _json.loads(line)
        if one.get('type') != 'map':
            continue
        for k, v in one["data"].items():
            if k.startswith('@'):
                k = k.lstrip('@')
            ret[k] = v
    return ret
308
309
def rand_port(stype=socket.SOCK_STREAM):
    """
    Get a single random unprivileged port of socket type @stype.
    """
    port, = rand_ports(1, stype)
    return port
315
316
def rand_ports(count, stype=socket.SOCK_STREAM):
    """
    Get a unique set of random unprivileged ports.

    @count AF_INET6 sockets of type @stype are bound to port 0 at the same
    time, so the kernel hands out distinct ports. All sockets are closed
    before returning, so the ports are not reserved afterwards.
    """
    held = []
    try:
        for _ in range(count):
            sock = socket.socket(socket.AF_INET6, stype)
            held.append(sock)
            sock.bind(("", 0))
        return [sock.getsockname()[1] for sock in held]
    finally:
        for sock in held:
            sock.close()
335
336
def wait_port_listen(port, proto="tcp", ns=None, host=None, sleep=0.005, deadline=5):
    """
    Wait until a socket with local port @port appears in /proc/net/<proto>*.

    For "tcp" the socket must also be in state 0A (listening); for other
    protocols only the local port is checked. @ns / @host select where /proc
    is read. Polls every @sleep seconds; raises TimeoutError (an Exception
    subclass) once @deadline seconds have passed.
    """
    end = time.monotonic() + deadline

    while True:
        data = cmd(f'cat /proc/net/{proto}*', ns=ns, host=host, shell=True).stdout
        for row in data.split("\n"):
            fields = row.split()
            # Skip header lines ("sl local_address ...") and blank lines;
            # data rows start with a "<slot>:" column.
            if len(fields) < 4 or not fields[0].endswith(":"):
                continue
            # Look only at the local_address column, so a socket whose
            # *remote* port equals @port is not mistaken for a listener.
            if int(fields[1].rpartition(":")[2], 16) != port:
                continue
            # For tcp additionally check the socket state column.
            if proto == "tcp" and fields[3] != "0A":
                continue
            return
        if time.monotonic() > end:
            raise TimeoutError("Waiting for port listen timed out")
        time.sleep(sleep)
353
354
def wait_file(fname, test_fn, sleep=0.005, deadline=5, encoding='utf-8'):
    """
    Poll a local file until test_fn(contents) returns a true value.

    The file is opened once and re-read from the beginning on each attempt,
    sleeping @sleep seconds between attempts. TimeoutError is raised once
    more than @deadline seconds have elapsed without the condition holding.
    """
    give_up_at = time.monotonic() + deadline

    with open(fname, "r", encoding=encoding) as fp:
        while not test_fn(fp.read()):
            fp.seek(0)
            if time.monotonic() > give_up_at:
                raise TimeoutError("Wait for file contents failed", fname)
            time.sleep(sleep)
371