# SPDX-License-Identifier: GPL-2.0

import builtins
import functools
import inspect
import signal
import sys
import time
import traceback
from .consts import KSFT_MAIN_NAME
from .utils import global_defer_queue

# Result of the case currently run by ksft_run(): reset to True before each
# case and cleared by any failed check (see _fail()).
KSFT_RESULT = None
# Aggregate of every result reported via ktap_result(); read by ksft_exit().
KSFT_RESULT_ALL = True
# When False, tests decorated with @ksft_disruptive are skipped (ksft_setup()).
KSFT_DISRUPTIVE = True
# SIGTERMs received during the current ksft_run(). Defined at module level so
# that _ksft_intr() never references an unbound global.
term_cnt = 0


class KsftFailEx(Exception):
    """Raised by a test case to fail it; ksft_run() reports it as a failure."""
    pass


class KsftSkipEx(Exception):
    """Raised by a test case to skip it; reported with a "SKIP" comment."""
    pass


class KsftXfailEx(Exception):
    """Raised by a test case for an expected failure; reported as "XFAIL"."""
    pass


class KsftTerminate(KeyboardInterrupt):
    """Raised from the SIGTERM handler; as a KeyboardInterrupt it stops ksft_run()."""
    pass


def ksft_pr(*objs, **kwargs):
    """Print *objs* as a TAP diagnostic line, i.e. prefixed with "#"."""
    print("#", *objs, **kwargs)


def _fail(*args):
    """
    Mark the current test case as failed and report where the check failed.

    Prints every stack frame below ksft_run() (the test case and whatever it
    called), excluding _fail() itself and the check helper that invoked it,
    followed by *args.
    """
    global KSFT_RESULT
    KSFT_RESULT = False

    stack = inspect.stack()
    try:
        started = False
        # stack[0] is _fail(), stack[1] the ksft_* check that called it.
        for frame in reversed(stack[2:]):
            # Start printing from the test case function
            if not started:
                if frame.function == 'ksft_run':
                    started = True
                continue

            ksft_pr("Check| At " + frame.filename + ", line " + str(frame.lineno) +
                    ", in " + frame.function + ":")
            # code_context is None when the source text is not available
            # (e.g. code compiled from a string).
            if frame.code_context:
                ksft_pr("Check| " + frame.code_context[0].strip())
    finally:
        # The list holds this very frame; drop it to break the reference cycle.
        del stack
    ksft_pr(*args)


def ksft_eq(a, b, comment=""):
    """Fail the current case unless a == b."""
    if a != b:
        _fail("Check failed", a, "!=", b, comment)


def ksft_ne(a, b, comment=""):
    """Fail the current case unless a != b."""
    if a == b:
        _fail("Check failed", a, "==", b, comment)


def ksft_true(a, comment=""):
    """Fail the current case unless *a* is truthy."""
    if not a:
        _fail("Check failed", a, "does not eval to True", comment)


def ksft_in(a, b, comment=""):
    """Fail the current case unless *a* is in *b*."""
    if a not in b:
        _fail("Check failed", a, "not in", b, comment)


def ksft_not_in(a, b, comment=""):
    """Fail the current case if *a* is in *b*."""
    if a in b:
        _fail("Check failed", a, "in", b, comment)


def ksft_is(a, b, comment=""):
    """Fail the current case unless *a* is *b* (identity)."""
    if a is not b:
        _fail("Check failed", a, "is not", b, comment)


def ksft_ge(a, b, comment=""):
    """Fail the current case unless a >= b."""
    if a < b:
        _fail("Check failed", a, "<", b, comment)


def ksft_lt(a, b, comment=""):
    """Fail the current case unless a < b."""
    if a >= b:
        _fail("Check failed", a, ">=", b, comment)


class ksft_raises:
    """
    Context manager checking that its body raises *expected_type*.

    A missing exception, or one of a different type, is recorded as a check
    failure. The match is by exact type (subclasses do not match). Only the
    expected exception is suppressed; any other one propagates. Whatever was
    raised (or None) is stored in the ``exception`` attribute.
    """

    def __init__(self, expected_type):
        self.exception = None
        self.expected_type = expected_type

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        expected_name = self.expected_type.__name__
        if exc_type is None:
            _fail(f"Expected exception {expected_name}, none raised")
        elif self.expected_type != exc_type:
            _fail(f"Expected exception {expected_name}, raised {exc_type.__name__}")
        self.exception = exc_val
        # Suppress the exception only if it's exactly the expected one
        return self.expected_type == exc_type


def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
    """
    Poll cond() every *sleep* seconds until it returns a truthy value.

    If it is still falsy after *deadline* seconds (monotonic clock), record a
    check failure and return.
    """
    end = time.monotonic() + deadline
    while True:
        if cond():
            return
        if time.monotonic() > end:
            _fail("Waiting for condition timed out", comment)
            return
        time.sleep(sleep)


def _case_name(case):
    """Return a printable name for the test case callable *case*."""
    name = getattr(case, "__name__", None)
    if name is None:
        # functools.partial objects carry no __name__; use the wrapped one.
        name = getattr(getattr(case, "func", None), "__name__", None)
    if name is None:
        name = case
    return str(name)


def ktap_result(ok, cnt=1, case="", comment=""):
    """
    Print one KTAP result line and fold *ok* into KSFT_RESULT_ALL.

    Format: ``[not ]ok <cnt> <KSFT_MAIN_NAME>[.<case name>][ # <comment>]``.
    """
    global KSFT_RESULT_ALL
    KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok

    res = ""
    if not ok:
        res += "not "
    res += "ok "
    res += str(cnt) + " "
    res += KSFT_MAIN_NAME
    if case:
        res += "." + _case_name(case)
    if comment:
        res += " # " + comment
    print(res)


def ksft_flush_defer():
    """
    Run and drain every queued deferred/cleanup callback.

    Entries are popped from the end of global_defer_queue (presumably a
    list/deque, so the most recently queued runs first - verify in utils).
    A callback that raises is logged and marks the case failed, but the
    remaining callbacks still run.
    """
    global KSFT_RESULT

    i = 0
    qlen_start = len(global_defer_queue)
    while global_defer_queue:
        i += 1
        entry = global_defer_queue.pop()
        try:
            entry.exec_only()
        # Deliberately broad: even a KsftTerminate/KeyboardInterrupt must
        # not prevent the remaining cleanups from running.
        except BaseException:
            ksft_pr(f"Exception while handling defer / cleanup (callback {i} of {qlen_start})!")
            tb = traceback.format_exc()
            for line in tb.strip().split('\n'):
                ksft_pr("Defer Exception|", line)
            KSFT_RESULT = False


def ksft_disruptive(func):
    """
    Decorator that marks the test as disruptive (e.g. the test
    that can down the interface). Disruptive tests can be skipped
    by passing DISRUPTIVE=False environment variable.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not KSFT_DISRUPTIVE:
            raise KsftSkipEx("marked as disruptive")
        return func(*args, **kwargs)
    return wrapper


def ksft_setup(env):
    """
    Setup test framework global state from the environment.

    Reads DISRUPTIVE (yes/true/no/false or an integer) from the mapping
    *env* and returns *env* unchanged.
    """

    def get_bool(env, name):
        value = env.get(name, "").lower()
        if value in ["yes", "true"]:
            return True
        if value in ["no", "false"]:
            return False
        try:
            return bool(int(value))
        except ValueError as e:
            raise Exception(f"failed to parse {name}") from e

    if "DISRUPTIVE" in env:
        global KSFT_DISRUPTIVE
        KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE")

    return env


def _ksft_intr(signum, frame):
    """SIGTERM handler: raise KsftTerminate once, ignore later signals."""
    # ksft runner.sh sends 2 SIGTERMs in a row on a timeout
    # if we don't ignore the second one it will stop us from handling cleanup
    global term_cnt
    term_cnt += 1
    if term_cnt == 1:
        raise KsftTerminate()
    else:
        ksft_pr(f"Ignoring SIGTERM (cnt: {term_cnt}), already exiting...")


def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
    """
    Run test cases in order and print the results in KTAP format.

    :param cases: iterable of test callables; it is copied, never modified.
    :param globs: mapping (e.g. ``globals()``) scanned for additional cases.
    :param case_pfx: name prefix(es) - a sequence of str or a single str -
        selecting which callables from *globs* are appended to the cases.
    :param args: positional arguments passed to every case.

    A KeyboardInterrupt (including KsftTerminate from SIGTERM) raised by a
    case stops the run after that case is reported.
    """
    global term_cnt
    global KSFT_RESULT

    # Copy so that cases discovered in globs never leak into the caller's list.
    cases = list(cases) if cases else []

    if globs and case_pfx:
        # str.startswith() accepts a tuple; wrapping a bare str avoids
        # matching it character by character.
        prefixes = (case_pfx,) if isinstance(case_pfx, str) else tuple(case_pfx)
        for key, value in globs.items():
            if callable(value) and key.startswith(prefixes):
                cases.append(value)

    term_cnt = 0
    prev_sigterm = signal.signal(signal.SIGTERM, _ksft_intr)

    totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}

    print("TAP version 13")
    print("1.." + str(len(cases)))

    stop = False
    try:
        for cnt, case in enumerate(cases, start=1):
            KSFT_RESULT = True
            comment = ""
            cnt_key = ""

            # Called directly from this frame: _fail() locates the start of
            # the failing call chain by the function name 'ksft_run'.
            try:
                case(*args)
            except KsftSkipEx as e:
                comment = "SKIP " + str(e)
                cnt_key = 'skip'
            except KsftXfailEx as e:
                comment = "XFAIL " + str(e)
                cnt_key = 'xfail'
            except BaseException as e:
                stop |= isinstance(e, KeyboardInterrupt)
                tb = traceback.format_exc()
                for line in tb.strip().split('\n'):
                    ksft_pr("Exception|", line)
                if stop:
                    ksft_pr(f"Stopping tests due to {type(e).__name__}.")
                KSFT_RESULT = False
                cnt_key = 'fail'

            ksft_flush_defer()

            if not cnt_key:
                cnt_key = 'pass' if KSFT_RESULT else 'fail'

            ktap_result(KSFT_RESULT, cnt, case, comment=comment)
            totals[cnt_key] += 1

            if stop:
                break
    finally:
        # signal.signal() returns None when the previous handler was not
        # installed from Python; None cannot be passed back in.
        if prev_sigterm is not None:
            signal.signal(signal.SIGTERM, prev_sigterm)

    print(
        f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0"
    )


def ksft_exit():
    """Exit the process with status 0 if every reported result passed, else 1."""
    sys.exit(0 if KSFT_RESULT_ALL else 1)