1# SPDX-License-Identifier: GPL-2.0 2 3import json as _json 4import os 5import re 6import select 7import socket 8import subprocess 9import time 10 11 12class CmdInitFailure(Exception): 13 """ Command failed to start. Only raised by bkg(). """ 14 def __init__(self, msg, cmd_obj): 15 super().__init__(msg + "\n" + repr(cmd_obj)) 16 self.cmd = cmd_obj 17 18 19class CmdExitFailure(Exception): 20 """ Command failed (returned non-zero exit code). """ 21 def __init__(self, msg, cmd_obj): 22 super().__init__(msg + "\n" + repr(cmd_obj)) 23 self.cmd = cmd_obj 24 25 26def fd_read_timeout(fd, timeout): 27 rlist, _, _ = select.select([fd], [], [], timeout) 28 if rlist: 29 return os.read(fd, 1024) 30 raise TimeoutError("Timeout waiting for fd read") 31 32 33class cmd: 34 """ 35 Execute a command on local or remote host. 36 37 @shell defaults to false, and class will try to split @comm into a list 38 if it's a string with spaces. 39 40 Use bkg() instead to run a command in the background. 41 """ 42 def __init__(self, comm, shell=None, fail=True, ns=None, background=False, 43 host=None, timeout=5, ksft_ready=None, ksft_wait=None): 44 if ns: 45 comm = f'ip netns exec {ns} ' + comm 46 47 self.stdout = None 48 self.stderr = None 49 self.ret = None 50 self.ksft_term_fd = None 51 52 self.host = host 53 self.comm = comm 54 55 if host: 56 self.proc = host.cmd(comm) 57 else: 58 # If user doesn't explicitly request shell try to avoid it. 59 if shell is None and isinstance(comm, str) and ' ' in comm: 60 comm = comm.split() 61 62 # ksft_wait lets us wait for the background process to fully start, 63 # we pass an FD to the child process, and wait for it to write back. 64 # Similarly term_fd tells child it's time to exit. 65 pass_fds = [] 66 env = os.environ.copy() 67 if ksft_wait is not None: 68 wait_fd, self.ksft_term_fd = os.pipe() 69 pass_fds.append(wait_fd) 70 env["KSFT_WAIT_FD"] = str(wait_fd) 71 ksft_ready = True # ksft_wait implies ready 72 if ksft_ready is not None: 73 rfd, ready_fd = os.pipe() 74 pass_fds.append(ready_fd) 75 env["KSFT_READY_FD"] = str(ready_fd) 76 77 self.proc = subprocess.Popen(comm, shell=shell, stdout=subprocess.PIPE, 78 stderr=subprocess.PIPE, pass_fds=pass_fds, 79 env=env) 80 if ksft_wait is not None: 81 os.close(wait_fd) 82 if ksft_ready is not None: 83 os.close(ready_fd) 84 msg = fd_read_timeout(rfd, ksft_wait) 85 os.close(rfd) 86 if not msg: 87 terminate = self.proc.poll() is None 88 self._process_terminate(terminate=terminate, timeout=1) 89 raise CmdInitFailure("Did not receive ready message", self) 90 if not background: 91 self.process(terminate=False, fail=fail, timeout=timeout) 92 93 def _process_terminate(self, terminate, timeout): 94 if terminate: 95 self.proc.terminate() 96 stdout, stderr = self.proc.communicate(timeout=timeout) 97 self.stdout = stdout.decode("utf-8") 98 self.stderr = stderr.decode("utf-8") 99 self.proc.stdout.close() 100 self.proc.stderr.close() 101 self.ret = self.proc.returncode 102 103 return stdout, stderr 104 105 def process(self, terminate=True, fail=None, timeout=5): 106 if fail is None: 107 fail = not terminate 108 109 if self.ksft_term_fd: 110 os.write(self.ksft_term_fd, b"1") 111 112 stdout, stderr = self._process_terminate(terminate=terminate, 113 timeout=timeout) 114 if self.proc.returncode != 0 and fail: 115 if len(stderr) > 0 and stderr[-1] == "\n": 116 stderr = stderr[:-1] 117 raise CmdExitFailure("Command failed", self) 118 119 def __repr__(self): 120 def str_fmt(name, s): 121 name += ': ' 122 return (name + s.strip().replace('\n', '\n' + ' ' * len(name))) 123 124 ret = "CMD" 125 if self.host: 126 ret += "[remote]" 127 if self.ret is None: 128 ret += f" (unterminated): {self.comm}\n" 129 elif self.ret == 0: 130 ret += f" (success): {self.comm}\n" 131 else: 132 ret += f": {self.comm}\n" 133 ret += f" EXIT: {self.ret}\n" 134 if self.stdout: 135 ret += str_fmt(" STDOUT", self.stdout) + "\n" 136 if self.stderr: 137 ret += str_fmt(" STDERR", self.stderr) + "\n" 138 return ret.strip() 139 140 141class bkg(cmd): 142 """ 143 Run a command in the background. 144 145 Examples usage: 146 147 Run a command on remote host, and wait for it to finish. 148 This is usually paired with wait_port_listen() to make sure 149 the command has initialized: 150 151 with bkg("socat ...", exit_wait=True, host=cfg.remote) as nc: 152 ... 153 154 Run a command and expect it to let us know that it's ready 155 by writing to a special file descriptor passed via KSFT_READY_FD. 156 Command will be terminated when we exit the context manager: 157 158 with bkg("my_binary", ksft_wait=5): 159 """ 160 def __init__(self, comm, shell=None, fail=None, ns=None, host=None, 161 exit_wait=False, ksft_ready=None, ksft_wait=None): 162 super().__init__(comm, background=True, 163 shell=shell, fail=fail, ns=ns, host=host, 164 ksft_ready=ksft_ready, ksft_wait=ksft_wait) 165 self.terminate = not exit_wait and not ksft_wait 166 self._exit_wait = exit_wait 167 self.check_fail = fail 168 169 if shell and self.terminate: 170 print("# Warning: combining shell and terminate is risky!") 171 print("# SIGTERM may not reach the child on zsh/ksh!") 172 173 def __enter__(self): 174 return self 175 176 def __exit__(self, ex_type, ex_value, ex_tb): 177 terminate = self.terminate 178 # Force termination on exception, but only if bkg() didn't already exit 179 # since forcing termination silences failures with fail=None 180 if self.proc.poll() is None: 181 terminate = terminate or (self._exit_wait and ex_type is not None) 182 return self.process(terminate=terminate, fail=self.check_fail) 183 184 185GLOBAL_DEFER_QUEUE = [] 186GLOBAL_DEFER_ARMED = False 187 188 189class defer: 190 def __init__(self, func, *args, **kwargs): 191 if not callable(func): 192 raise Exception("defer created with un-callable object, did you call the function instead of passing its name?") 193 194 self.func = func 195 self.args = args 196 self.kwargs = kwargs 197 198 if not GLOBAL_DEFER_ARMED: 199 raise Exception("defer queue not armed, did you use defer() outside of a test case?") 200 self._queue = GLOBAL_DEFER_QUEUE 201 self._queue.append(self) 202 203 def __enter__(self): 204 return self 205 206 def __exit__(self, ex_type, ex_value, ex_tb): 207 return self.exec() 208 209 def exec_only(self): 210 self.func(*self.args, **self.kwargs) 211 212 def cancel(self): 213 self._queue.remove(self) 214 215 def exec(self): 216 self.cancel() 217 self.exec_only() 218 219 220def tool(name, args, json=None, ns=None, host=None): 221 cmd_str = name + ' ' 222 if json: 223 cmd_str += '--json ' 224 cmd_str += args 225 cmd_obj = cmd(cmd_str, ns=ns, host=host) 226 if json: 227 return _json.loads(cmd_obj.stdout) 228 return cmd_obj 229 230 231def bpftool(args, json=None, ns=None, host=None): 232 return tool('bpftool', args, json=json, ns=ns, host=host) 233 234 235def ip(args, json=None, ns=None, host=None): 236 if ns: 237 args = f'-netns {ns} ' + args 238 return tool('ip', args, json=json, host=host) 239 240 241def ethtool(args, json=None, ns=None, host=None): 242 return tool('ethtool', args, json=json, ns=ns, host=host) 243 244 245def bpftrace(expr, json=None, ns=None, host=None, timeout=None): 246 """ 247 Run bpftrace and return map data (if json=True). 248 The output of bpftrace is inconvenient, so the helper converts 249 to a dict indexed by map name, e.g.: 250 { 251 "@": { ... }, 252 "@map2": { ... }, 253 } 254 """ 255 cmd_arr = ['bpftrace'] 256 # Throw in --quiet if json, otherwise the output has two objects 257 if json: 258 cmd_arr += ['-f', 'json', '-q'] 259 if timeout: 260 expr += ' interval:s:' + str(timeout) + ' { exit(); }' 261 timeout += 20 262 cmd_arr += ['-e', expr] 263 cmd_obj = cmd(cmd_arr, ns=ns, host=host, shell=False, timeout=timeout) 264 if json: 265 # bpftrace prints objects as lines 266 ret = {} 267 for l in cmd_obj.stdout.split('\n'): 268 if not l.strip(): 269 continue 270 one = _json.loads(l) 271 if one.get('type') != 'map': 272 continue 273 for k, v in one["data"].items(): 274 if k.startswith('@'): 275 k = k.lstrip('@') 276 ret[k] = v 277 return ret 278 return cmd_obj 279 280 281def rand_port(stype=socket.SOCK_STREAM): 282 """ 283 Get a random unprivileged port. 284 """ 285 return rand_ports(1, stype)[0] 286 287 288def rand_ports(count, stype=socket.SOCK_STREAM): 289 """ 290 Get a unique set of random unprivileged ports. 291 """ 292 sockets = [] 293 ports = [] 294 295 try: 296 for _ in range(count): 297 s = socket.socket(socket.AF_INET6, stype) 298 sockets.append(s) 299 s.bind(("", 0)) 300 ports.append(s.getsockname()[1]) 301 finally: 302 for s in sockets: 303 s.close() 304 305 return ports 306 307 308def wait_port_listen(port, proto="tcp", ns=None, host=None, sleep=0.005, deadline=5): 309 end = time.monotonic() + deadline 310 311 pattern = f":{port:04X} .* " 312 if proto == "tcp": # for tcp protocol additionally check the socket state 313 pattern += "0A" 314 pattern = re.compile(pattern) 315 316 while True: 317 data = cmd(f'cat /proc/net/{proto}*', ns=ns, host=host, shell=True).stdout 318 for row in data.split("\n"): 319 if pattern.search(row): 320 return 321 if time.monotonic() > end: 322 raise Exception("Waiting for port listen timed out") 323 time.sleep(sleep) 324 325 326def wait_file(fname, test_fn, sleep=0.005, deadline=5, encoding='utf-8'): 327 """ 328 Wait for file contents on the local system to satisfy a condition. 329 test_fn() should take one argument (file contents) and return whether 330 condition is met. 331 """ 332 end = time.monotonic() + deadline 333 334 with open(fname, "r", encoding=encoding) as fp: 335 while True: 336 if test_fn(fp.read()): 337 break 338 fp.seek(0) 339 if time.monotonic() > end: 340 raise TimeoutError("Wait for file contents failed", fname) 341 time.sleep(sleep) 342