# SPDX-License-Identifier: GPL-2.0

import json as _json
import os
import re
import select
import socket
import subprocess
import time


class CmdExitFailure(Exception):
    """
    Raised when a command exits with a non-zero status and @fail was set.

    The originating cmd object is available as self.cmd, so callers can
    inspect its return code and captured output.
    """
    def __init__(self, msg, cmd_obj):
        super().__init__(msg)
        self.cmd = cmd_obj


def fd_read_timeout(fd, timeout):
    """
    Read up to 1024 bytes from @fd, waiting at most @timeout seconds.

    A @timeout of None waits indefinitely (select() semantics).
    Returns the bytes read, which is b"" if the writer closed the fd.
    Raises TimeoutError if @fd did not become readable in time.
    """
    rlist, _, _ = select.select([fd], [], [], timeout)
    if rlist:
        return os.read(fd, 1024)
    raise TimeoutError("Timeout waiting for fd read")


class cmd:
    """
    Execute a command on local or remote host.

    @shell defaults to false, and class will try to split @comm into a list
    if it's a string with spaces.

    @ns runs the command via "ip netns exec <ns>"; @comm may be either a
    string or a list of arguments.

    Use bkg() instead to run a command in the background.
    """
    def __init__(self, comm, shell=None, fail=True, ns=None, background=False,
                 host=None, timeout=5, ksft_ready=None, ksft_wait=None):
        if ns:
            if isinstance(comm, str):
                comm = f'ip netns exec {ns} ' + comm
            else:
                # Argument lists (e.g. from bpftrace()) cannot be concatenated
                # with a string; prepend the wrapper as separate arguments.
                comm = ['ip', 'netns', 'exec', str(ns)] + list(comm)

        self.stdout = None
        self.stderr = None
        self.ret = None
        self.ksft_term_fd = None

        self.host = host
        self.comm = comm

        if host:
            self.proc = host.cmd(comm)
        else:
            # If user doesn't explicitly request shell try to avoid it.
            if shell is None and isinstance(comm, str) and ' ' in comm:
                comm = comm.split()

            # ksft_wait lets us wait for the background process to fully start,
            # we pass an FD to the child process, and wait for it to write back.
            # Similarly term_fd tells child it's time to exit.
            pass_fds = []
            env = os.environ.copy()
            if ksft_wait is not None:
                wait_fd, self.ksft_term_fd = os.pipe()
                pass_fds.append(wait_fd)
                env["KSFT_WAIT_FD"] = str(wait_fd)
                ksft_ready = True  # ksft_wait implies ready
            if ksft_ready is not None:
                rfd, ready_fd = os.pipe()
                pass_fds.append(ready_fd)
                env["KSFT_READY_FD"] = str(ready_fd)

            self.proc = subprocess.Popen(comm, shell=shell, stdout=subprocess.PIPE,
                                         stderr=subprocess.PIPE, pass_fds=pass_fds,
                                         env=env)
            # The child now holds its own copies of these pipe ends;
            # the parent's copies are no longer needed.
            if ksft_wait is not None:
                os.close(wait_fd)
            if ksft_ready is not None:
                os.close(ready_fd)
                # A None ksft_wait (ksft_ready alone) waits without a timeout.
                # Release the read end even if the wait times out.
                try:
                    msg = fd_read_timeout(rfd, ksft_wait)
                finally:
                    os.close(rfd)
                # An empty read means the child closed its end (or exited)
                # without writing a ready message.
                if not msg:
                    raise Exception("Did not receive ready message")
        if not background:
            self.process(terminate=False, fail=fail, timeout=timeout)

    def process(self, terminate=True, fail=None, timeout=5):
        """
        Wait for the command and collect its output and exit status.

        @terminate calls terminate() on the process before waiting.
        @fail defaults to "not terminate", because a terminated process is
        expected to exit non-zero. When true, a non-zero exit raises
        CmdExitFailure.
        Decoded output is stored in self.stdout / self.stderr and the exit
        status in self.ret.
        """
        if fail is None:
            fail = not terminate

        if self.ksft_term_fd is not None:
            # Tell a ksft_wait child it's time to exit, then release the write
            # end so the fd is not leaked (the child also sees EOF).
            os.write(self.ksft_term_fd, b"1")
            os.close(self.ksft_term_fd)
            self.ksft_term_fd = None
        if terminate:
            self.proc.terminate()
        # NOTE(review): the first positional parameter of communicate() is
        # `input`, not `timeout`, so @timeout is currently not enforced (stdin
        # is not a pipe, so the value is ignored). Switching to timeout=timeout
        # would start enforcing a 5 s default on every command; audit
        # long-running callers (e.g. bpftrace() with timeout >= 5) first.
        stdout, stderr = self.proc.communicate(timeout)
        self.stdout = stdout.decode("utf-8")
        self.stderr = stderr.decode("utf-8")
        self.proc.stdout.close()
        self.proc.stderr.close()
        self.ret = self.proc.returncode

        if self.ret != 0 and fail:
            # Report the decoded text: formatting the raw bytes would print
            # b'...', and bytes[-1] is an int so it never equals "\n".
            err = self.stderr
            if err.endswith("\n"):
                err = err[:-1]
            raise CmdExitFailure("Command failed: %s\nSTDOUT: %s\nSTDERR: %s" %
                                 (self.proc.args, self.stdout, err), self)

    def __repr__(self):
        def str_fmt(name, s):
            # Indent continuation lines so they align under the first line.
            name += ': '
            return (name + s.strip().replace('\n', '\n' + ' ' * len(name)))

        ret = "CMD"
        if self.host:
            ret += "[remote]"
        if self.ret is None:
            ret += f" (unterminated): {self.comm}\n"
        elif self.ret == 0:
            ret += f" (success): {self.comm}\n"
        else:
            ret += f": {self.comm}\n"
            ret += f"    EXIT: {self.ret}\n"
        if self.stdout:
            ret += str_fmt("  STDOUT", self.stdout) + "\n"
        if self.stderr:
            ret += str_fmt("  STDERR", self.stderr) + "\n"
        return ret.strip()


class bkg(cmd):
    """
    Run a command in the background.

    Example usage:

    Run a command on remote host, and wait for it to finish.
    This is usually paired with wait_port_listen() to make sure
    the command has initialized:

        with bkg("socat ...", exit_wait=True, host=cfg.remote) as nc:
            ...

    Run a command and expect it to let us know that it's ready
    by writing to a special file descriptor passed via KSFT_READY_FD.
    When we exit the context manager the command is told to exit by a
    write to the file descriptor passed via KSFT_WAIT_FD, and we wait
    for it to finish:

        with bkg("my_binary", ksft_wait=5):
            ...

    Otherwise the command is terminated when the context manager exits.
    """
    def __init__(self, comm, shell=None, fail=None, ns=None, host=None,
                 exit_wait=False, ksft_ready=None, ksft_wait=None):
        super().__init__(comm, background=True,
                         shell=shell, fail=fail, ns=ns, host=host,
                         ksft_ready=ksft_ready, ksft_wait=ksft_wait)
        self.terminate = not exit_wait and not ksft_wait
        self._exit_wait = exit_wait
        self.check_fail = fail

        if shell and self.terminate:
            print("# Warning: combining shell and terminate is risky!")
            print("#          SIGTERM may not reach the child on zsh/ksh!")

    def __enter__(self):
        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        # Force termination on exception
        terminate = self.terminate or (self._exit_wait and ex_type is not None)
        return self.process(terminate=terminate, fail=self.check_fail)


GLOBAL_DEFER_QUEUE = []
GLOBAL_DEFER_ARMED = False


class defer:
    """
    Schedule func(*args, **kwargs) to run later (typically as cleanup).

    The call is appended to GLOBAL_DEFER_QUEUE, which must be armed
    (GLOBAL_DEFER_ARMED) - presumably by the test-case runner, which is
    expected to drain the queue. It can also be run early with exec(),
    dropped with cancel(), or run on exit when used as a context manager.
    """
    def __init__(self, func, *args, **kwargs):
        if not callable(func):
            raise Exception("defer created with un-callable object, did you call the function instead of passing its name?")

        self.func = func
        self.args = args
        self.kwargs = kwargs

        if not GLOBAL_DEFER_ARMED:
            raise Exception("defer queue not armed, did you use defer() outside of a test case?")
        self._queue = GLOBAL_DEFER_QUEUE
        self._queue.append(self)

    def __enter__(self):
        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        return self.exec()

    def exec_only(self):
        """Run the deferred call without touching the queue."""
        self.func(*self.args, **self.kwargs)

    def cancel(self):
        """Remove from the queue; raises ValueError if no longer queued."""
        self._queue.remove(self)

    def exec(self):
        """Remove from the queue, then run the deferred call."""
        self.cancel()
        self.exec_only()


def tool(name, args, json=None, ns=None, host=None):
    """
    Run "<name> [--json] <args>" and return the cmd object.

    With @json, the command's stdout is parsed as JSON and returned instead.
    """
    cmd_str = name + ' '
    if json:
        cmd_str += '--json '
    cmd_str += args
    cmd_obj = cmd(cmd_str, ns=ns, host=host)
    if json:
        return _json.loads(cmd_obj.stdout)
    return cmd_obj


def bpftool(args, json=None, ns=None, host=None):
    """Run bpftool via tool()."""
    return tool('bpftool', args, json=json, ns=ns, host=host)


def ip(args, json=None, ns=None, host=None):
    """Run ip via tool(); @ns is passed as "-netns" to ip itself."""
    if ns:
        args = f'-netns {ns} ' + args
    return tool('ip', args, json=json, host=host)


def ethtool(args, json=None, ns=None, host=None):
    """Run ethtool via tool()."""
    return tool('ethtool', args, json=json, ns=ns, host=host)


def bpftrace(expr, json=None, ns=None, host=None, timeout=None):
    """
    Run bpftrace and return map data (if json=True).
    The output of bpftrace is inconvenient, so the helper converts
    to a dict indexed by map name with the leading '@' stripped, e.g.:
    {
       "": { ... },        # from map "@"
       "map2": { ... },    # from map "@map2"
    }
    With @timeout (seconds) an interval probe is appended that makes
    bpftrace exit after that long.
    """
    cmd_arr = ['bpftrace']
    # Throw in --quiet if json, otherwise the output has two objects
    if json:
        cmd_arr += ['-f', 'json', '-q']
    if timeout:
        expr += ' interval:s:' + str(timeout) + ' { exit(); }'
    cmd_arr += ['-e', expr]
    cmd_obj = cmd(cmd_arr, ns=ns, host=host, shell=False)
    if json:
        # bpftrace prints objects as lines
        ret = {}
        for l in cmd_obj.stdout.split('\n'):
            if not l.strip():
                continue
            one = _json.loads(l)
            if one.get('type') != 'map':
                continue
            for k, v in one["data"].items():
                if k.startswith('@'):
                    k = k.lstrip('@')
                ret[k] = v
        return ret
    return cmd_obj


def rand_port(stype=socket.SOCK_STREAM):
    """
    Get a random unprivileged port.
    """
    with socket.socket(socket.AF_INET6, stype) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


def wait_port_listen(port, proto="tcp", ns=None, host=None, sleep=0.005, deadline=5):
    """
    Poll /proc/net/<proto>* until a row containing ":<port in hex> " shows up.

    For tcp the row must additionally contain state 0A (listening).
    Raises Exception if nothing matches within @deadline seconds.
    """
    end = time.monotonic() + deadline

    pattern = f":{port:04X} .* "
    if proto == "tcp":  # for tcp protocol additionally check the socket state
        pattern += "0A"
    pattern = re.compile(pattern)

    while True:
        data = cmd(f'cat /proc/net/{proto}*', ns=ns, host=host, shell=True).stdout
        for row in data.split("\n"):
            if pattern.search(row):
                return
        if time.monotonic() > end:
            raise Exception("Waiting for port listen timed out")
        time.sleep(sleep)


def wait_file(fname, test_fn, sleep=0.005, deadline=5, encoding='utf-8'):
    """
    Wait for file contents on the local system to satisfy a condition.
    test_fn() should take one argument (file contents) and return whether
    condition is met.
    Raises TimeoutError if the condition is not met within @deadline seconds.
    """
    end = time.monotonic() + deadline

    with open(fname, "r", encoding=encoding) as fp:
        while True:
            if test_fn(fp.read()):
                break
            fp.seek(0)
            if time.monotonic() > end:
                raise TimeoutError("Wait for file contents failed", fname)
            time.sleep(sleep)