# SPDX-License-Identifier: GPL-2.0

import json as _json
import os
import re
import select
import socket
import subprocess
import time


class CmdInitFailure(Exception):
    """
    Command failed to start.

    Raised by cmd() / bkg() when a command started with ksft_ready or
    ksft_wait closes its ready descriptor without writing anything to it.
    The failed command object is available as @cmd.
    """
    def __init__(self, msg, cmd_obj):
        super().__init__(msg + "\n" + repr(cmd_obj))
        self.cmd = cmd_obj


class CmdExitFailure(Exception):
    """
    Command failed (returned non-zero exit code).

    The failed command object is available as @cmd.
    """
    def __init__(self, msg, cmd_obj):
        super().__init__(msg + "\n" + repr(cmd_obj))
        self.cmd = cmd_obj


class CmdExitZeroFailure(CmdExitFailure):
    """ Command succeeded (returned zero exit code), but expected failure. """


def fd_read_timeout(fd, timeout):
    """
    Read up to 1024 bytes from @fd, waiting at most @timeout seconds.

    A @timeout of None waits indefinitely. Returns the bytes read (empty
    on EOF); raises TimeoutError if @fd did not become readable in time.
    """
    rlist, _, _ = select.select([fd], [], [], timeout)
    if rlist:
        return os.read(fd, 1024)
    raise TimeoutError("Timeout waiting for fd read")


class cmd:
    """
    Execute a command on local or remote host.

    @shell defaults to false, and class will try to split @comm into a list
    if it's a string with spaces.

    @ns runs the command inside a network namespace, @comm may be a string
    or an argument list in that case.

    After the command has been processed the decoded output is available
    as @stdout / @stderr and the exit code as @ret.

    Use bkg() instead to run a command in the background.
    """
    def __init__(self, comm, shell=None, fail=True, expect_fail=False, ns=None,
                 background=False, host=None, timeout=5, ksft_ready=None,
                 ksft_wait=None):
        if ns:
            comm = self._ns_wrap(comm, ns)

        self.stdout = None
        self.stderr = None
        self.ret = None
        self.ksft_term_fd = None

        self.host = host
        self.comm = comm

        if host:
            self.proc = host.cmd(comm)
        else:
            self._spawn_local(comm, shell, ksft_ready, ksft_wait)

        if not background:
            self.process(terminate=False, fail=fail, expect_fail=expect_fail,
                         timeout=timeout)

    @staticmethod
    def _ns_wrap(comm, ns):
        """
        Prefix @comm so that it runs inside namespace @ns.

        A string @comm yields a string, any other sequence yields a list,
        so argument-list commands (e.g. from bpftrace()) keep working.
        """
        if hasattr(ns, 'user_ns_path'):
            prefix = ['nsenter', f'--user={ns.user_ns_path}',
                      f'--net={ns.net_ns_path}', '--setuid=0', '--setgid=0',
                      '--']
        else:
            prefix = ['ip', 'netns', 'exec', str(ns)]

        if isinstance(comm, str):
            return ' '.join(prefix) + ' ' + comm
        return prefix + list(comm)

    def _spawn_local(self, comm, shell, ksft_ready, ksft_wait):
        """
        Start @comm as a local subprocess and store it in self.proc.

        If requested, wait for the child to report readiness. Raises
        TimeoutError if nothing arrives within @ksft_wait seconds, and
        CmdInitFailure if the child closes the descriptor without writing.
        """
        # If user doesn't explicitly request shell try to avoid it.
        if shell is None and isinstance(comm, str) and ' ' in comm:
            comm = comm.split()

        # ksft_wait lets us wait for the background process to fully start,
        # we pass an FD to the child process, and wait for it to write back.
        # Similarly term_fd tells child it's time to exit.
        pass_fds = []
        ready_rfd = None
        env = os.environ.copy()
        if ksft_wait is not None:
            wait_fd, self.ksft_term_fd = os.pipe()
            pass_fds.append(wait_fd)
            env["KSFT_WAIT_FD"] = str(wait_fd)
            ksft_ready = True  # ksft_wait implies ready
        if ksft_ready is not None:
            ready_rfd, ready_fd = os.pipe()
            pass_fds.append(ready_fd)
            env["KSFT_READY_FD"] = str(ready_fd)

        try:
            self.proc = subprocess.Popen(comm, shell=shell,
                                         stdout=subprocess.PIPE,
                                         stderr=subprocess.PIPE,
                                         pass_fds=pass_fds, env=env)
        except BaseException:
            # Nothing was started, don't leak our ends of the pipes.
            self._close_term_fd()
            if ready_rfd is not None:
                os.close(ready_rfd)
            raise
        finally:
            # The child (if any) holds its own copies of these.
            for fd in pass_fds:
                os.close(fd)

        if ready_rfd is None:
            return

        try:
            msg = fd_read_timeout(ready_rfd, ksft_wait)
        except TimeoutError:
            # The constructor is about to fail, so nobody would be left
            # holding a reference to reap the child.
            self._abort_startup()
            raise
        finally:
            os.close(ready_rfd)
        if not msg:
            self._abort_startup()
            raise CmdInitFailure("Did not receive ready message", self)

    def _close_term_fd(self):
        """ Close our end of the KSFT_WAIT_FD pipe, if we still hold it. """
        if self.ksft_term_fd is not None:
            os.close(self.ksft_term_fd)
            self.ksft_term_fd = None

    def _abort_startup(self):
        """ Reap a child which failed to report that it's ready. """
        self._close_term_fd()
        terminate = self.proc.poll() is None
        self._process_terminate(terminate=terminate, timeout=1)

    def _process_terminate(self, terminate, timeout):
        """
        Collect the output and exit code of the process.

        Sends SIGTERM first if @terminate is set, then waits up to @timeout
        seconds for the process to exit. Populates @stdout, @stderr and @ret
        and returns the raw (bytes) stdout and stderr.
        """
        if terminate:
            self.proc.terminate()
        stdout, stderr = self.proc.communicate(timeout=timeout)
        self.stdout = stdout.decode("utf-8")
        self.stderr = stderr.decode("utf-8")
        self.proc.stdout.close()
        self.proc.stderr.close()
        self.ret = self.proc.returncode

        return stdout, stderr

    def process(self, terminate=True, fail=None, expect_fail=False, timeout=5):
        """
        Wait for the command to finish and check its exit code.

        @fail defaults to True unless we @terminate the command ourselves.
        Raises CmdExitFailure on unexpected failure and CmdExitZeroFailure
        if the command succeeded while @expect_fail is set.
        """
        if fail is None:
            fail = not terminate

        if self.ksft_term_fd is not None:
            # Tell the child it's time to exit, the pipe is single-use.
            try:
                os.write(self.ksft_term_fd, b"1")
            finally:
                self._close_term_fd()

        self._process_terminate(terminate=terminate, timeout=timeout)

        # Fail on unexpected test failure if fail.
        # Fail on unexpected test success if expect_fail.
        # Fail on negative returncode if either:
        #   Set by subprocess on crash or signal, this is never expected failure.
        if ((self.proc.returncode != 0 and fail) or
                (self.proc.returncode < 0 and expect_fail)):
            raise CmdExitFailure("Command failed", self)
        if self.proc.returncode == 0 and expect_fail:
            raise CmdExitZeroFailure("Command succeeded (expected fail)", self)

    def __repr__(self):
        def str_fmt(name, s):
            # Align continuation lines under the first line of output.
            name += ': '
            return (name + s.strip().replace('\n', '\n' + ' ' * len(name)))

        ret = "CMD"
        if self.host:
            ret += "[remote]"
        if self.ret is None:
            ret += f" (unterminated): {self.comm}\n"
        elif self.ret == 0:
            ret += f" (success): {self.comm}\n"
        else:
            ret += f": {self.comm}\n"
            ret += f" EXIT: {self.ret}\n"
        if self.stdout:
            ret += str_fmt(" STDOUT", self.stdout) + "\n"
        if self.stderr:
            ret += str_fmt(" STDERR", self.stderr) + "\n"
        return ret.strip()


class bkg(cmd):
    """
    Run a command in the background.

    Example usage:

    Run a command on remote host, and wait for it to finish.
    This is usually paired with wait_port_listen() to make sure
    the command has initialized:

        with bkg("socat ...", exit_wait=True, host=cfg.remote) as nc:
            ...

    Run a command and expect it to let us know that it's ready
    by writing to a special file descriptor passed via KSFT_READY_FD.
    Command will be told to exit (via KSFT_WAIT_FD) when we exit
    the context manager:

        with bkg("my_binary", ksft_wait=5):
    """
    def __init__(self, comm, shell=None, fail=None, expect_fail=None,
                 ns=None, host=None, exit_wait=False, ksft_ready=None,
                 ksft_wait=None):
        super().__init__(comm, background=True,
                         shell=shell, fail=fail, expect_fail=expect_fail,
                         ns=ns, host=host, ksft_ready=ksft_ready,
                         ksft_wait=ksft_wait)
        # Only SIGTERM the command if we were not asked to wait for it and
        # there is no KSFT_WAIT_FD to ask it to exit.
        self.terminate = not exit_wait and not ksft_wait
        self._exit_wait = exit_wait
        self.check_fail = fail
        self.expect_fail = expect_fail

        if shell and self.terminate:
            print("# Warning: combining shell and terminate is risky!")
            print("#          SIGTERM may not reach the child on zsh/ksh!")

    def __enter__(self):
        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        terminate = self.terminate
        # Force termination on exception, but only if bkg() didn't already exit
        # since forcing termination silences failures with fail=None
        if self.proc.poll() is None:
            terminate = terminate or (self._exit_wait and ex_type is not None)
        return self.process(terminate=terminate, fail=self.check_fail,
                            expect_fail=self.expect_fail)


GLOBAL_DEFER_QUEUE = []
GLOBAL_DEFER_ARMED = False


class defer:
    """
    Queue a call to @func(*args, **kwargs) on the global defer queue.

    The call can be run early with exec(), dropped with cancel(), or
    used as a context manager, in which case it runs on exit.
    Can only be created while the global defer queue is armed.
    """
    def __init__(self, func, *args, **kwargs):
        if not callable(func):
            raise Exception("defer created with un-callable object, did you call the function instead of passing its name?")

        self.func = func
        self.args = args
        self.kwargs = kwargs

        if not GLOBAL_DEFER_ARMED:
            raise Exception("defer queue not armed, did you use defer() outside of a test case?")
        self._queue = GLOBAL_DEFER_QUEUE
        self._queue.append(self)

    def __enter__(self):
        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        return self.exec()

    def exec_only(self):
        """ Run the deferred call, leaving it on the queue. """
        self.func(*self.args, **self.kwargs)

    def cancel(self):
        """ Remove the deferred call from the queue without running it. """
        self._queue.remove(self)

    def exec(self):
        """ Remove the deferred call from the queue and run it. """
        self.cancel()
        self.exec_only()


def tool(name, args, json=None, ns=None, host=None):
    """
    Run tool @name with @args.

    If @json is set the JSON output flag is added ('-json' for tc,
    '--json' for everything else) and the parsed output is returned,
    otherwise the cmd object is returned.
    """
    cmd_str = name + ' '
    if json:
        if name == 'tc':
            cmd_str += '-json '
        else:
            cmd_str += '--json '
    cmd_str += args
    cmd_obj = cmd(cmd_str, ns=ns, host=host)
    if json:
        return _json.loads(cmd_obj.stdout)
    return cmd_obj


def bpftool(args, json=None, ns=None, host=None):
    """ Helper to call bpftool with standard set of optional args. """
    return tool('bpftool', args, json=json, ns=ns, host=host)


def ip(args, json=None, ns=None, host=None):
    """
    Helper to call ip with standard set of optional args.
    @ns is passed to ip itself via -netns.
    """
    if ns:
        args = f'-netns {ns} ' + args
    return tool('ip', args, json=json, host=host)


def tc(args, json=None, ns=None, host=None):
    """
    Helper to call tc with standard set of optional args.
    @ns is passed to tc itself via -netns.
    """
    if ns:
        args = f'-netns {ns} ' + args
    return tool('tc', args, json=json, host=host)


def ethtool(args, json=None, ns=None, host=None):
    """ Helper to call ethtool with standard set of optional args. """
    return tool('ethtool', args, json=json, ns=ns, host=host)


def bpftrace(expr, json=None, ns=None, host=None, timeout=None):
    """
    Run bpftrace and return map data (if json=True).
    The output of bpftrace is inconvenient, so the helper converts
    to a dict indexed by map name with the leading '@' stripped, e.g.:
     {
       "": { ... },
       "map2": { ... },
     }

    If @timeout is set the script is made to exit after that many
    seconds, and the command itself is given 20 more seconds to finish.
    """
    cmd_arr = ['bpftrace']
    # Throw in --quiet if json, otherwise the output has two objects
    if json:
        cmd_arr += ['-f', 'json', '-q']
    if timeout:
        expr += ' interval:s:' + str(timeout) + ' { exit(); }'
        timeout += 20
    cmd_arr += ['-e', expr]
    cmd_obj = cmd(cmd_arr, ns=ns, host=host, shell=False, timeout=timeout)
    if json:
        # bpftrace prints objects as lines
        ret = {}
        for line in cmd_obj.stdout.split('\n'):
            if not line.strip():
                continue
            one = _json.loads(line)
            if one.get('type') != 'map':
                continue
            for k, v in one["data"].items():
                if k.startswith('@'):
                    k = k.lstrip('@')
                ret[k] = v
        return ret
    return cmd_obj


def rand_port(stype=socket.SOCK_STREAM):
    """
    Get a random unprivileged port.
    """
    return rand_ports(1, stype)[0]


def rand_ports(count, stype=socket.SOCK_STREAM):
    """
    Get a unique set of random unprivileged ports.
    """
    sockets = []
    ports = []

    try:
        # Keep every socket open until we're done, so that the kernel
        # can't hand out the same port twice.
        for _ in range(count):
            s = socket.socket(socket.AF_INET6, stype)
            sockets.append(s)
            s.bind(("", 0))
            ports.append(s.getsockname()[1])
    finally:
        for s in sockets:
            s.close()

    return ports


def wait_port_listen(port, proto="tcp", ns=None, host=None, sleep=0.005, deadline=5):
    """
    Wait for a socket bound to @port to show up in /proc/net/@proto*.

    Polls every @sleep seconds and raises an Exception if the socket
    did not appear within @deadline seconds.
    """
    end = time.monotonic() + deadline

    pattern = f":{port:04X} .* "
    if proto == "tcp":  # for tcp protocol additionally check the socket state
        pattern += "0A"
    pattern = re.compile(pattern)

    while True:
        data = cmd(f'cat /proc/net/{proto}*', ns=ns, host=host, shell=True).stdout
        for row in data.split("\n"):
            if pattern.search(row):
                return
        if time.monotonic() > end:
            raise Exception("Waiting for port listen timed out")
        time.sleep(sleep)


def wait_file(fname, test_fn, sleep=0.005, deadline=5, encoding='utf-8'):
    """
    Wait for file contents on the local system to satisfy a condition.
    test_fn() should take one argument (file contents) and return whether
    condition is met.

    Polls every @sleep seconds and raises TimeoutError if the condition
    was not met within @deadline seconds.
    """
    end = time.monotonic() + deadline

    with open(fname, "r", encoding=encoding) as fp:
        while True:
            if test_fn(fp.read()):
                break
            fp.seek(0)
            if time.monotonic() > end:
                raise TimeoutError("Wait for file contents failed", fname)
            time.sleep(sleep)