# SPDX-License-Identifier: GPL-2.0

import errno
import json as _json
import random
import re
import socket
import subprocess
import time


class CmdExitFailure(Exception):
    """
    Raised by cmd.process() when a command exits non-zero and failure
    checking is enabled. The originating cmd object is kept in ``self.cmd``
    so callers can inspect its stdout, stderr and ret.
    """
    def __init__(self, msg, cmd_obj):
        super().__init__(msg)
        self.cmd = cmd_obj


class cmd:
    """
    Run a command locally (via a shell by default), inside a netns,
    or through a host object.

    Unless *background* is set, the constructor waits for the command and
    the decoded output is available as ``stdout``/``stderr`` with the exit
    status in ``ret``.

    :param comm: command string to run.
    :param shell: forwarded to subprocess.Popen for local commands.
    :param fail: raise CmdExitFailure on non-zero exit (foreground only).
    :param ns: if set, the command is prefixed with ``ip netns exec <ns>``.
    :param background: start the command without waiting; call process()
        later to collect results.
    :param host: if set, ``host.cmd(comm)`` starts the command instead of a
        local subprocess.Popen.
    :param timeout: seconds to wait for a foreground command to finish.
    """
    def __init__(self, comm, shell=True, fail=True, ns=None, background=False, host=None, timeout=5):
        if ns:
            comm = f'ip netns exec {ns} ' + comm

        self.stdout = None
        self.stderr = None
        self.ret = None

        self.comm = comm
        if host:
            self.proc = host.cmd(comm)
        else:
            self.proc = subprocess.Popen(comm, shell=shell, stdout=subprocess.PIPE,
                                         stderr=subprocess.PIPE)
        if not background:
            self.process(terminate=False, fail=fail, timeout=timeout)

    def process(self, terminate=True, fail=None, timeout=5):
        """
        Wait for the command to exit and collect its output.

        :param terminate: call ``proc.terminate()`` before waiting.
        :param fail: raise CmdExitFailure on non-zero exit; when None it
            defaults to ``not terminate``.
        :param timeout: seconds to wait for the process to exit.
        :raises subprocess.TimeoutExpired: if the process is still running
            after *timeout*; the process is not killed in that case.
        :raises CmdExitFailure: on non-zero exit when *fail* is true.
        """
        if fail is None:
            fail = not terminate

        if terminate:
            self.proc.terminate()
        # communicate()'s first positional parameter is `input`; the timeout
        # must be passed by keyword or it is silently ignored.
        stdout, stderr = self.proc.communicate(timeout=timeout)
        self.stdout = stdout.decode("utf-8")
        self.stderr = stderr.decode("utf-8")
        self.proc.stdout.close()
        self.proc.stderr.close()
        self.ret = self.proc.returncode

        if self.proc.returncode != 0 and fail:
            # Use the decoded text: indexing raw bytes yields ints (never
            # equal to "\n"), and bytes would show up as b'...' reprs.
            err = self.stderr
            if err.endswith("\n"):
                err = err[:-1]
            raise CmdExitFailure("Command failed: %s\nSTDOUT: %s\nSTDERR: %s" %
                                 (self.proc.args, self.stdout, err), self)


class bkg(cmd):
    """
    Run a command in the background for the duration of a ``with`` block.

    On leaving the block the process is terminated (or, with *exit_wait*,
    simply waited for) and its output collected via process(). *fail* is
    forwarded to process(); when left as None, the exit status is checked
    only in the *exit_wait* case.
    """
    def __init__(self, comm, shell=True, fail=None, ns=None, host=None,
                 exit_wait=False):
        super().__init__(comm, background=True,
                         shell=shell, fail=fail, ns=ns, host=host)
        self.terminate = not exit_wait
        self.check_fail = fail

    def __enter__(self):
        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        return self.process(terminate=self.terminate, fail=self.check_fail)


# Pending defer objects, appended in creation order.
global_defer_queue = []


class defer:
    """
    Register ``func(*args, **kwargs)`` to be run later.

    The object is appended to global_defer_queue on creation. exec() removes
    it from the queue and runs the call; cancel() only removes it (raising
    ValueError if it is no longer queued). Used as a context manager, the
    call runs when the ``with`` block exits.
    """
    def __init__(self, func, *args, **kwargs):
        if not callable(func):
            raise Exception("defer created with un-callable object, did you call the function instead of passing its name?")

        self.func = func
        self.args = args
        self.kwargs = kwargs

        self._queue = global_defer_queue
        self._queue.append(self)

    def __enter__(self):
        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        return self.exec()

    def exec_only(self):
        """Run the deferred call without touching the queue."""
        self.func(*self.args, **self.kwargs)

    def cancel(self):
        """Remove this entry from the queue without running it."""
        self._queue.remove(self)

    def exec(self):
        """Remove this entry from the queue, then run the deferred call."""
        self.cancel()
        self.exec_only()


def tool(name, args, json=None, ns=None, host=None):
    """
    Run ``name [--json] args`` via cmd().

    When *json* is truthy, ``--json`` is inserted before *args* and the
    parsed stdout is returned; otherwise the cmd object is returned.
    """
    cmd_str = name + ' '
    if json:
        cmd_str += '--json '
    cmd_str += args
    cmd_obj = cmd(cmd_str, ns=ns, host=host)
    if json:
        return _json.loads(cmd_obj.stdout)
    return cmd_obj


def ip(args, json=None, ns=None, host=None):
    """
    Run the ``ip`` tool. A namespace is selected with ip's own ``-netns``
    option rather than an ``ip netns exec`` prefix.
    """
    if ns:
        args = f'-netns {ns} ' + args
    return tool('ip', args, json=json, host=host)


def ethtool(args, json=None, ns=None, host=None):
    """Run the ``ethtool`` tool; see tool() for the return value."""
    return tool('ethtool', args, json=json, ns=ns, host=host)


def rand_port():
    """
    Get a random unprivileged port, try to make sure it's not already used.

    Up to 1000 random ports in [10000, 65535] are probed by binding an IPv6
    TCP socket to them; the first that binds is returned. Bind errors other
    than EADDRINUSE are propagated.
    """
    for _ in range(1000):
        port = random.randint(10000, 65535)
        try:
            with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
                s.bind(("", port))
            return port
        except OSError as e:
            if e.errno != errno.EADDRINUSE:
                raise
    raise Exception("Can't find any free unprivileged port")


def wait_port_listen(port, proto="tcp", ns=None, host=None, sleep=0.005, deadline=5):
    """
    Poll /proc/net/<proto>* until a socket bound locally to *port* appears.

    For TCP the socket must also be in state 0A (TCP_LISTEN). Polls every
    *sleep* seconds and raises Exception after *deadline* seconds.
    """
    end = time.monotonic() + deadline

    # Anchor the port to the local_address column ("<sl>: <addr>:<PORT> ");
    # an unanchored search would also match the remote-address column.
    pattern = rf"^\s*\d+: [0-9A-F]+:{port:04X} "
    if proto == "tcp":  # for tcp protocol additionally check the socket state
        pattern += r"[0-9A-F]+:[0-9A-F]{4} 0A"
    pattern = re.compile(pattern)

    while True:
        data = cmd(f'cat /proc/net/{proto}*', ns=ns, host=host, shell=True).stdout
        for row in data.split("\n"):
            if pattern.search(row):
                return
        if time.monotonic() > end:
            raise Exception("Waiting for port listen timed out")
        time.sleep(sleep)