# SPDX-License-Identifier: GPL-2.0

import builtins
import functools
import inspect
import signal
import sys
import time
import traceback
from .consts import KSFT_MAIN_NAME
from .utils import global_defer_queue

# Result of the test case currently being run: ksft_run() sets it to True
# before each case; failed checks and unexpected exceptions set it to False.
KSFT_RESULT = None
# Cleared by ktap_result() once any case is reported "not ok"; ksft_exit()
# turns it into the process exit status.
KSFT_RESULT_ALL = True
# When False, tests wrapped with @ksft_disruptive are skipped; set from the
# DISRUPTIVE variable by ksft_setup().
KSFT_DISRUPTIVE = True
# Number of SIGTERMs received during the current ksft_run(); reset at the
# start of every run and read by the SIGTERM handler _ksft_intr().
term_cnt = 0


class KsftFailEx(Exception):
    """Exception a test case may raise to fail.

    ksft_run() reports any exception other than KsftSkipEx / KsftXfailEx
    as a failure.
    """


class KsftSkipEx(Exception):
    """Raised by a test case to be reported as SKIP by ksft_run()."""


class KsftXfailEx(Exception):
    """Raised by a test case to be reported as XFAIL by ksft_run()."""


class KsftTerminate(KeyboardInterrupt):
    """Raised by the SIGTERM handler; as a KeyboardInterrupt it stops ksft_run()."""


def ksft_pr(*objs, **kwargs):
    """Print objs as a '#'-prefixed diagnostic line, always flushing stdout."""
    kwargs["flush"] = True
    print("#", *objs, **kwargs)


def _fail(*args):
    """Mark the current test case as failed and print where the check failed.

    The call stack is walked outermost-first and printing starts with the
    frame right after ksft_run(), i.e. at the test case. The two innermost
    frames (_fail itself and the check helper that called it) are skipped.
    Finally args are printed as the failure message.
    """
    global KSFT_RESULT
    KSFT_RESULT = False

    stack = inspect.stack()
    started = False
    for frame in reversed(stack[2:]):
        # Start printing from the test case function
        if not started:
            if frame.function == 'ksft_run':
                started = True
            continue

        ksft_pr("Check| At " + frame.filename + ", line " + str(frame.lineno) +
                ", in " + frame.function + ":")
        # code_context may be None when the source is unavailable
        # (e.g. exec'd or frozen code); only print it when present.
        if frame.code_context:
            ksft_pr("Check| " + frame.code_context[0].strip())
    ksft_pr(*args)


def ksft_eq(a, b, comment=""):
    """Check that a == b."""
    if a != b:
        _fail("Check failed", a, "!=", b, comment)


def ksft_ne(a, b, comment=""):
    """Check that a != b."""
    if a == b:
        _fail("Check failed", a, "==", b, comment)


def ksft_true(a, comment=""):
    """Check that a is truthy."""
    if not a:
        _fail("Check failed", a, "does not eval to True", comment)


def ksft_in(a, b, comment=""):
    """Check that a is in b."""
    if a not in b:
        _fail("Check failed", a, "not in", b, comment)


def ksft_not_in(a, b, comment=""):
    """Check that a is not in b."""
    if a in b:
        _fail("Check failed", a, "in", b, comment)


def ksft_is(a, b, comment=""):
    """Check that a is b (identity)."""
    if a is not b:
        _fail("Check failed", a, "is not", b, comment)


def ksft_ge(a, b, comment=""):
    """Check that a >= b."""
    if a < b:
        _fail("Check failed", a, "<", b, comment)


def ksft_gt(a, b, comment=""):
    """Check that a > b."""
    if a <= b:
        _fail("Check failed", a, "<=", b, comment)


def ksft_lt(a, b, comment=""):
    """Check that a < b."""
    if a >= b:
        _fail("Check failed", a, ">=", b, comment)


class ksft_raises:
    """Context manager checking that its body raises expected_type.

    Only the exact type matches (compared with ==, so subclasses do not
    count). A matching exception is stored in .exception and suppressed.
    If nothing is raised, or a different type is raised, the check fails;
    a different exception is still stored in .exception and then propagates.
    """

    def __init__(self, expected_type):
        self.exception = None
        self.expected_type = expected_type

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            _fail(f"Expected exception {str(self.expected_type.__name__)}, none raised")
        elif self.expected_type != exc_type:
            _fail(f"Expected exception {str(self.expected_type.__name__)}, raised {str(exc_type.__name__)}")
        self.exception = exc_val
        # Suppress the exception if it's the expected one
        return self.expected_type == exc_type


def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
    """Poll cond() until it returns a truthy value.

    cond is re-evaluated every `sleep` seconds. If it is still falsy once
    `deadline` seconds have passed (measured with time.monotonic()), a check
    failure is recorded and the function returns without raising.
    """
    end = time.monotonic() + deadline
    while True:
        if cond():
            return
        if time.monotonic() > end:
            _fail("Waiting for condition timed out", comment)
            return
        time.sleep(sleep)


def ktap_result(ok, cnt=1, case="", comment=""):
    """Print one KTAP result line and fold ok into KSFT_RESULT_ALL.

    The line has the form
    "[not ]ok <cnt> <KSFT_MAIN_NAME>[.<case.__name__>][ # <comment>]".
    """
    global KSFT_RESULT_ALL
    KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok

    res = ""
    if not ok:
        res += "not "
    res += "ok "
    res += str(cnt) + " "
    res += KSFT_MAIN_NAME
    if case:
        res += "." + str(case.__name__)
    if comment:
        res += " # " + comment
    print(res, flush=True)


def ksft_flush_defer():
    """Run every entry queued in global_defer_queue, newest first.

    Entries are popped from the end of the queue (LIFO, assuming it is a
    list) and run via exec_only(). A failing entry has its traceback printed
    and marks the current test as failed, but the remaining entries still
    run.
    """
    global KSFT_RESULT

    i = 0
    qlen_start = len(global_defer_queue)
    while global_defer_queue:
        i += 1
        entry = global_defer_queue.pop()
        try:
            entry.exec_only()
        # BaseException on purpose: even an interrupt raised inside one
        # cleanup callback must not skip the remaining cleanups.
        except BaseException:
            ksft_pr(f"Exception while handling defer / cleanup (callback {i} of {qlen_start})!")
            tb = traceback.format_exc()
            for line in tb.strip().split('\n'):
                ksft_pr("Defer Exception|", line)
            KSFT_RESULT = False


def ksft_disruptive(func):
    """
    Decorator that marks the test as disruptive (e.g. the test
    that can down the interface). Disruptive tests can be skipped
    by setting the DISRUPTIVE environment variable to a false value
    (see ksft_setup()); they then raise KsftSkipEx instead of running.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not KSFT_DISRUPTIVE:
            raise KsftSkipEx("marked as disruptive")
        return func(*args, **kwargs)
    return wrapper


def ksft_setup(env):
    """
    Setup test framework global state from the environment.

    env is a mapping of variable names to (presumably string) values.
    Only DISRUPTIVE is consumed here; it accepts "yes"/"true"/"no"/"false"
    (case-insensitive) or an integer. Returns env unchanged.

    Raises:
        Exception: if DISRUPTIVE is present but cannot be parsed.
    """

    def get_bool(env, name):
        value = env.get(name, "").lower()
        if value in ["yes", "true"]:
            return True
        if value in ["no", "false"]:
            return False
        try:
            return bool(int(value))
        except (TypeError, ValueError) as e:
            raise Exception(f"failed to parse {name}") from e

    if "DISRUPTIVE" in env:
        global KSFT_DISRUPTIVE
        KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE")

    return env


def _ksft_intr(signum, frame):
    """SIGTERM handler: raise KsftTerminate on the first signal only."""
    # ksft runner.sh sends 2 SIGTERMs in a row on a timeout
    # if we don't ignore the second one it will stop us from handling cleanup
    global term_cnt
    term_cnt += 1
    if term_cnt == 1:
        raise KsftTerminate()
    else:
        ksft_pr(f"Ignoring SIGTERM (cnt: {term_cnt}), already exiting...")


def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
    """Run test cases and report the results in KTAP/TAP format.

    Args:
        cases: iterable of callables to run, in order. Not modified.
        globs: optional mapping (typically a module's globals()) scanned
            for additional callable test cases.
        case_pfx: prefix string, or iterable of prefix strings; callables in
            globs whose name starts with any of them are appended to cases.
        args: positional arguments passed to every test case.

    Each case starts with KSFT_RESULT = True; a failed check or an
    unexpected exception marks it as failed. KsftSkipEx / KsftXfailEx are
    reported as SKIP / XFAIL. Deferred cleanups are flushed after every
    case. A KeyboardInterrupt (including KsftTerminate, raised on SIGTERM)
    stops the run after the current case has been cleaned up and reported.
    """
    global KSFT_RESULT, term_cnt

    # Copy, so appending discovered cases never mutates the caller's list
    # (this also lets callers pass any iterable).
    cases = list(cases or [])

    if globs and case_pfx:
        # A bare string would otherwise be iterated character by character.
        if isinstance(case_pfx, str):
            case_pfx = [case_pfx]
        prefixes = tuple(case_pfx)
        for key, value in globs.items():
            if callable(value) and key.startswith(prefixes):
                cases.append(value)

    term_cnt = 0
    prev_sigterm = signal.signal(signal.SIGTERM, _ksft_intr)

    totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}

    # Restore the caller's SIGTERM handler even if an exception escapes the
    # loop (e.g. KsftTerminate raised outside a test case's own try block).
    try:
        print("TAP version 13", flush=True)
        print("1.." + str(len(cases)), flush=True)

        stop = False
        for cnt, case in enumerate(cases, start=1):
            KSFT_RESULT = True
            comment = ""
            cnt_key = ""

            try:
                case(*args)
            except KsftSkipEx as e:
                comment = "SKIP " + str(e)
                cnt_key = 'skip'
            except KsftXfailEx as e:
                comment = "XFAIL " + str(e)
                cnt_key = 'xfail'
            except BaseException as e:
                # KeyboardInterrupt (incl. KsftTerminate) aborts the run
                # once this case has been cleaned up and reported.
                stop |= isinstance(e, KeyboardInterrupt)
                tb = traceback.format_exc()
                for line in tb.strip().split('\n'):
                    ksft_pr("Exception|", line)
                if stop:
                    ksft_pr(f"Stopping tests due to {type(e).__name__}.")
                KSFT_RESULT = False
                cnt_key = 'fail'

            ksft_flush_defer()

            # No exception: classify by the checks performed and by any
            # cleanup failure flagged in ksft_flush_defer().
            if not cnt_key:
                cnt_key = 'pass' if KSFT_RESULT else 'fail'

            ktap_result(KSFT_RESULT, cnt, case, comment=comment)
            totals[cnt_key] += 1

            if stop:
                break
    finally:
        signal.signal(signal.SIGTERM, prev_sigterm)

    print(
        f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0",
        flush=True,
    )


def ksft_exit():
    """Exit the process: status 0 if no case was reported "not ok", else 1."""
    sys.exit(0 if KSFT_RESULT_ALL else 1)