# SPDX-License-Identifier: GPL-2.0

import json as _json
import os
import re
import select
import socket
import subprocess
import time


class CmdInitFailure(Exception):
    """ Command failed to start. Only raised by bkg(). """
    def __init__(self, msg, cmd_obj):
        super().__init__(msg + "\n" + repr(cmd_obj))
        self.cmd = cmd_obj


class CmdExitFailure(Exception):
    """ Command failed (returned non-zero exit code). """
    def __init__(self, msg, cmd_obj):
        super().__init__(msg + "\n" + repr(cmd_obj))
        self.cmd = cmd_obj


class CmdExitZeroFailure(CmdExitFailure):
    """ Command succeeded (returned zero exit code), but expected failure. """


def fd_read_timeout(fd, timeout):
    """
    Read up to 1024 bytes from @fd, waiting at most @timeout seconds
    (None waits indefinitely). Returns b"" if the write side was closed.
    Raises TimeoutError if @fd did not become readable in time.
    """
    rlist, _, _ = select.select([fd], [], [], timeout)
    if rlist:
        return os.read(fd, 1024)
    raise TimeoutError("Timeout waiting for fd read")


class cmd:
    """
    Execute a command on local or remote host.

    @shell defaults to false, and class will try to split @comm into a list
    if it's a string with spaces.

    @ns runs the command via "ip netns exec <ns>"; @comm may be either a
    string or an argument list.

    @ksft_ready / @ksft_wait (local commands only) give the child pipe FDs
    through the KSFT_READY_FD / KSFT_WAIT_FD environment variables, so it
    can report that it is ready and learn when it should exit.

    Unless @background is set, the command is waited for (up to @timeout
    seconds) and its exit status is checked against @fail / @expect_fail.

    Use bkg() instead to run a command in the background.
    """
    def __init__(self, comm, shell=None, fail=True, expect_fail=False, ns=None,
                 background=False, host=None, timeout=5, ksft_ready=None,
                 ksft_wait=None):
        if ns:
            # Support both the string and the argv-list form of @comm;
            # concatenating a str prefix with a list would raise TypeError.
            if isinstance(comm, str):
                comm = f'ip netns exec {ns} ' + comm
            else:
                comm = ['ip', 'netns', 'exec', str(ns)] + list(comm)

        self.stdout = None
        self.stderr = None
        self.ret = None
        self.ksft_term_fd = None

        self.host = host
        self.comm = comm

        if host:
            self.proc = host.cmd(comm)
        else:
            # If user doesn't explicitly request shell try to avoid it.
            if shell is None and isinstance(comm, str) and ' ' in comm:
                comm = comm.split()

            # ksft_wait lets us wait for the background process to fully start,
            # we pass an FD to the child process, and wait for it to write back.
            # Similarly term_fd tells child it's time to exit.
            pass_fds = []  # child-side pipe ends, dropped by the parent below
            env = os.environ.copy()
            rfd = None     # parent-side read end of the "ready" pipe
            if ksft_wait is not None:
                wait_fd, self.ksft_term_fd = os.pipe()
                pass_fds.append(wait_fd)
                env["KSFT_WAIT_FD"] = str(wait_fd)
                ksft_ready = True  # ksft_wait implies ready
            if ksft_ready is not None:
                rfd, ready_fd = os.pipe()
                pass_fds.append(ready_fd)
                env["KSFT_READY_FD"] = str(ready_fd)

            try:
                self.proc = subprocess.Popen(comm, shell=shell,
                                             stdout=subprocess.PIPE,
                                             stderr=subprocess.PIPE,
                                             pass_fds=pass_fds, env=env)
            except BaseException:
                # Spawning failed (e.g. binary not found): don't leak our ends.
                if rfd is not None:
                    os.close(rfd)
                self._close_term_fd()
                raise
            finally:
                # The child owns its copies now (or never started). The parent
                # must drop these so EOF on the pipes becomes observable.
                for fd in pass_fds:
                    os.close(fd)

            if rfd is not None:
                self._wait_ready(rfd, ksft_wait)
        if not background:
            self.process(terminate=False, fail=fail, expect_fail=expect_fail,
                         timeout=timeout)

    def _wait_ready(self, rfd, timeout):
        """
        Wait up to @timeout seconds (None = forever) for the child to write to
        its KSFT_READY_FD. Always closes @rfd. On timeout (TimeoutError is
        re-raised) or on EOF without data (CmdInitFailure), the child is
        terminated first so it does not linger.
        """
        try:
            msg = fd_read_timeout(rfd, timeout)
        except TimeoutError:
            self._close_term_fd()
            self._process_terminate(terminate=self.proc.poll() is None,
                                    timeout=1)
            raise
        finally:
            os.close(rfd)
        if not msg:
            self._close_term_fd()
            terminate = self.proc.poll() is None
            self._process_terminate(terminate=terminate, timeout=1)
            raise CmdInitFailure("Did not receive ready message", self)

    def _close_term_fd(self):
        """ Close the KSFT_WAIT_FD write end, if we still hold one. """
        if self.ksft_term_fd is not None:
            os.close(self.ksft_term_fd)
            self.ksft_term_fd = None

    def _signal_term(self):
        """
        Tell a ksft_wait child it is time to exit by writing to, then closing,
        the KSFT_WAIT_FD pipe. No-op when there is no such pipe.
        """
        if self.ksft_term_fd is None:
            return
        try:
            os.write(self.ksft_term_fd, b"1")
        except BrokenPipeError:
            # Child already exited and dropped its end; the exit status
            # check in process() reports whatever happened to it.
            pass
        finally:
            self._close_term_fd()

    def _process_terminate(self, terminate, timeout):
        """
        Optionally SIGTERM the process, then collect its output and exit code
        into self.stdout / self.stderr / self.ret. Returns raw (stdout, stderr).
        """
        if terminate:
            self.proc.terminate()
        stdout, stderr = self.proc.communicate(timeout=timeout)
        self.stdout = stdout.decode("utf-8")
        self.stderr = stderr.decode("utf-8")
        self.proc.stdout.close()
        self.proc.stderr.close()
        self.ret = self.proc.returncode

        return stdout, stderr

    def process(self, terminate=True, fail=None, expect_fail=False, timeout=5):
        """
        Finish the command: signal a ksft_wait child, optionally terminate it,
        collect its output, and check the exit code.

        @fail defaults to "not @terminate", because a terminated process is
        expected to exit with a signal.
        Raises CmdExitFailure / CmdExitZeroFailure on unexpected exit status.
        """
        if fail is None:
            fail = not terminate

        self._signal_term()

        self._process_terminate(terminate=terminate, timeout=timeout)

        # Fail on unexpected test failure if fail.
        # Fail on unexpected test success if expect_fail.
        # Fail on negative returncode even with expect_fail: subprocess sets it
        # on crash or signal, which is never the expected kind of failure.
        if (self.proc.returncode != 0 and fail or
                (self.proc.returncode < 0 and expect_fail)):
            raise CmdExitFailure("Command failed", self)
        elif self.proc.returncode == 0 and expect_fail:
            raise CmdExitZeroFailure("Command succeeded (expected fail)", self)

    def __repr__(self):
        def str_fmt(name, s):
            name += ': '
            return (name + s.strip().replace('\n', '\n' + ' ' * len(name)))

        ret = "CMD"
        if self.host:
            ret += "[remote]"
        if self.ret is None:
            ret += f" (unterminated): {self.comm}\n"
        elif self.ret == 0:
            ret += f" (success): {self.comm}\n"
        else:
            ret += f": {self.comm}\n"
            ret += f"  EXIT: {self.ret}\n"
        if self.stdout:
            ret += str_fmt("  STDOUT", self.stdout) + "\n"
        if self.stderr:
            ret += str_fmt("  STDERR", self.stderr) + "\n"
        return ret.strip()


class bkg(cmd):
    """
    Run a command in the background.

    Example usage:

    Run a command on remote host, and wait for it to finish.
    This is usually paired with wait_port_listen() to make sure
    the command has initialized:

        with bkg("socat ...", exit_wait=True, host=cfg.remote) as nc:
            ...

    Run a command and expect it to let us know that it's ready
    by writing to a special file descriptor passed via KSFT_READY_FD.
    Command will be terminated when we exit the context manager:

        with bkg("my_binary", ksft_wait=5):
            ...
    """
    def __init__(self, comm, shell=None, fail=None, expect_fail=None,
                 ns=None, host=None, exit_wait=False, ksft_ready=None,
                 ksft_wait=None):
        super().__init__(comm, background=True,
                         shell=shell, fail=fail, expect_fail=expect_fail,
                         ns=ns, host=host, ksft_ready=ksft_ready,
                         ksft_wait=ksft_wait)
        # SIGTERM on exit unless we wait for the command (exit_wait) or tell
        # it to exit through the KSFT_WAIT_FD pipe (ksft_wait).
        self.terminate = not exit_wait and not ksft_wait
        self._exit_wait = exit_wait
        self.check_fail = fail
        self.expect_fail = expect_fail

        if shell and self.terminate:
            print("# Warning: combining shell and terminate is risky!")
            print("#          SIGTERM may not reach the child on zsh/ksh!")

    def __enter__(self):
        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        terminate = self.terminate
        # Force termination on exception, but only if bkg() didn't already exit
        # since forcing termination silences failures with fail=None
        if self.proc.poll() is None:
            terminate = terminate or (self._exit_wait and ex_type is not None)
        return self.process(terminate=terminate, fail=self.check_fail,
                            expect_fail=self.expect_fail)


# Cleanups registered by defer(). defer() refuses to register unless
# GLOBAL_DEFER_ARMED is true; presumably the test runner arms it around
# each test case and drains the queue afterwards (not visible here).
GLOBAL_DEFER_QUEUE = []
GLOBAL_DEFER_ARMED = False


class defer:
    """
    Register func(*args, **kwargs) as a deferred cleanup on GLOBAL_DEFER_QUEUE.
    Can also be used as a context manager, which runs (and unregisters) the
    cleanup on exit. Raises if @func is not callable or the queue is not armed.
    """
    def __init__(self, func, *args, **kwargs):
        if not callable(func):
            raise Exception("defer created with un-callable object, did you call the function instead of passing its name?")

        self.func = func
        self.args = args
        self.kwargs = kwargs

        if not GLOBAL_DEFER_ARMED:
            raise Exception("defer queue not armed, did you use defer() outside of a test case?")
        self._queue = GLOBAL_DEFER_QUEUE
        self._queue.append(self)

    def __enter__(self):
        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        return self.exec()

    def exec_only(self):
        """ Run the cleanup without touching the queue. """
        self.func(*self.args, **self.kwargs)

    def cancel(self):
        """ Unregister the cleanup (ValueError if it is no longer queued). """
        self._queue.remove(self)

    def exec(self):
        """ Unregister the cleanup, then run it. """
        self.cancel()
        self.exec_only()


def tool(name, args, json=None, ns=None, host=None):
    """
    Run "<name> <args>" and return the cmd object, or the parsed JSON output
    when @json is set (tc spells the flag -json, other tools --json).
    """
    cmd_str = name + ' '
    if json:
        if name == 'tc':
            cmd_str += '-json '
        else:
            cmd_str += '--json '
    cmd_str += args
    cmd_obj = cmd(cmd_str, ns=ns, host=host)
    if json:
        return _json.loads(cmd_obj.stdout)
    return cmd_obj


def bpftool(args, json=None, ns=None, host=None):
    return tool('bpftool', args, json=json, ns=ns, host=host)


def ip(args, json=None, ns=None, host=None):
    # ip selects the namespace itself via -netns rather than "ip netns exec".
    if ns:
        args = f'-netns {ns} ' + args
    return tool('ip', args, json=json, host=host)


def tc(args, json=None, ns=None, host=None):
    """ Helper to call tc with standard set of optional args. """
    if ns:
        args = f'-netns {ns} ' + args
    return tool('tc', args, json=json, host=host)


def ethtool(args, json=None, ns=None, host=None):
    return tool('ethtool', args, json=json, ns=ns, host=host)


def bpftrace(expr, json=None, ns=None, host=None, timeout=None):
    """
    Run bpftrace and return map data (if json=True).
    The output of bpftrace is inconvenient, so the helper converts
    to a dict indexed by map name, with the leading '@' stripped, e.g.:
    {
       "": { ... },      # the anonymous map "@"
       "map2": { ... },  # "@map2"
    }
    If @timeout is given, an interval probe calls exit() after that many
    seconds, and the command is allowed 20 extra seconds to finish.
    Without @json, the cmd object is returned.
    """
    cmd_arr = ['bpftrace']
    # Throw in --quiet if json, otherwise the output has two objects
    if json:
        cmd_arr += ['-f', 'json', '-q']
    if timeout:
        expr += ' interval:s:' + str(timeout) + ' { exit(); }'
        timeout += 20
    cmd_arr += ['-e', expr]
    cmd_obj = cmd(cmd_arr, ns=ns, host=host, shell=False, timeout=timeout)
    if json:
        # bpftrace prints objects as lines
        ret = {}
        for l in cmd_obj.stdout.split('\n'):
            if not l.strip():
                continue
            one = _json.loads(l)
            if one.get('type') != 'map':
                continue
            for k, v in one["data"].items():
                if k.startswith('@'):
                    k = k.lstrip('@')
                ret[k] = v
        return ret
    return cmd_obj


def rand_port(stype=socket.SOCK_STREAM):
    """
    Get a random unprivileged port.
    """
    return rand_ports(1, stype)[0]


def rand_ports(count, stype=socket.SOCK_STREAM):
    """
    Get a unique set of random unprivileged ports.
    """
    sockets = []
    ports = []

    try:
        # Keep every socket bound until all ports are picked, so the kernel
        # cannot hand out the same port twice.
        for _ in range(count):
            s = socket.socket(socket.AF_INET6, stype)
            sockets.append(s)
            s.bind(("", 0))
            ports.append(s.getsockname()[1])
    finally:
        for s in sockets:
            s.close()

    return ports


def wait_port_listen(port, proto="tcp", ns=None, host=None, sleep=0.005, deadline=5):
    """
    Poll /proc/net/<proto>* every @sleep seconds until a row mentioning @port
    appears (for tcp, also in state 0A, i.e. listening).
    Raises Exception after @deadline seconds.
    """
    end = time.monotonic() + deadline

    pattern = f":{port:04X} .* "
    if proto == "tcp": # for tcp protocol additionally check the socket state
        pattern += "0A"
    pattern = re.compile(pattern)

    while True:
        data = cmd(f'cat /proc/net/{proto}*', ns=ns, host=host, shell=True).stdout
        for row in data.split("\n"):
            if pattern.search(row):
                return
        if time.monotonic() > end:
            raise Exception("Waiting for port listen timed out")
        time.sleep(sleep)


def wait_file(fname, test_fn, sleep=0.005, deadline=5, encoding='utf-8'):
    """
    Wait for file contents on the local system to satisfy a condition.
    test_fn() should take one argument (file contents) and return whether
    condition is met.
    """
    end = time.monotonic() + deadline

    with open(fname, "r", encoding=encoding) as fp:
        while True:
            if test_fn(fp.read()):
                break
            fp.seek(0)
            if time.monotonic() > end:
                raise TimeoutError("Wait for file contents failed", fname)
            time.sleep(sleep)