# SPDX-License-Identifier: GPL-2.0

import functools
import inspect
import signal
import sys
import time
import traceback
from .consts import KSFT_MAIN_NAME
from .utils import global_defer_queue

# Result of the test case currently being run by ksft_run(); reset to True
# before each case and cleared by any failed check (see _fail()).
KSFT_RESULT = None
# AND of every result reported through ktap_result(); drives ksft_exit().
KSFT_RESULT_ALL = True
# Whether tests wrapped with @ksft_disruptive may run; set by ksft_setup().
KSFT_DISRUPTIVE = True

# Number of SIGTERMs received during the current ksft_run(). Initialized at
# module level so _ksft_intr() never refers to an unbound global.
term_cnt = 0


class KsftFailEx(Exception):
    """Test failure exception.

    ksft_run() has no dedicated handler for it; like any other exception
    escaping a test case, it is reported as a failure.
    """


class KsftSkipEx(Exception):
    """Raise from a test case to report it as skipped ("SKIP <message>")."""


class KsftXfailEx(Exception):
    """Raise from a test case to report an expected failure ("XFAIL <message>")."""


class KsftTerminate(KeyboardInterrupt):
    """Raised by the SIGTERM handler; as a KeyboardInterrupt it stops ksft_run()."""


def ksft_pr(*objs, **kwargs):
    """Print *objs* as a "#"-prefixed diagnostic line, always flushed."""
    kwargs["flush"] = True
    print("#", *objs, **kwargs)


def _fail(*args):
    """Mark the current test case as failed and print where the check failed.

    The call stack is walked from the outermost frame inwards. Every frame
    below ksft_run() (the test case and its callees, down to the caller of
    the check helper) is printed with its location and source line. Then
    *args* are printed.
    """
    global KSFT_RESULT
    KSFT_RESULT = False

    stack = inspect.stack()
    started = False
    # stack[0] is this function, stack[1] the check helper that called it.
    for frame in reversed(stack[2:]):
        # Start printing from the test case function
        if not started:
            if frame.function == 'ksft_run':
                started = True
            continue

        ksft_pr("Check| At " + frame.filename + ", line " + str(frame.lineno) +
                ", in " + frame.function + ":")
        # code_context is None when the source is not available; don't let
        # that turn a check failure into a TypeError.
        if frame.code_context:
            ksft_pr("Check| " + frame.code_context[0].strip())
    ksft_pr(*args)


def ksft_eq(a, b, comment=""):
    """Check that a == b."""
    if a != b:
        _fail("Check failed", a, "!=", b, comment)


def ksft_ne(a, b, comment=""):
    """Check that a != b."""
    if a == b:
        _fail("Check failed", a, "==", b, comment)


def ksft_true(a, comment=""):
    """Check that *a* is truthy."""
    if not a:
        _fail("Check failed", a, "does not eval to True", comment)


def ksft_not_none(a, comment=""):
    """Check that *a* is not None."""
    if a is None:
        _fail("Check failed", a, "is None", comment)


def ksft_in(a, b, comment=""):
    """Check that *a* is contained in *b*."""
    if a not in b:
        _fail("Check failed", a, "not in", b, comment)


def ksft_not_in(a, b, comment=""):
    """Check that *a* is not contained in *b*."""
    if a in b:
        _fail("Check failed", a, "in", b, comment)


def ksft_is(a, b, comment=""):
    """Check that *a* and *b* are the same object."""
    if a is not b:
        _fail("Check failed", a, "is not", b, comment)


def ksft_ge(a, b, comment=""):
    """Check that a >= b."""
    if a < b:
        _fail("Check failed", a, "<", b, comment)


def ksft_gt(a, b, comment=""):
    """Check that a > b."""
    if a <= b:
        _fail("Check failed", a, "<=", b, comment)


def ksft_lt(a, b, comment=""):
    """Check that a < b."""
    if a >= b:
        _fail("Check failed", a, ">=", b, comment)


class ksft_raises:
    """Context manager checking that its body raises *expected_type*.

    Only an exception whose type is exactly *expected_type* is suppressed
    (subclasses do not match). If nothing or a different exception is
    raised, the check fails via _fail(); a different exception also keeps
    propagating. The exception seen (or None) is stored in ``self.exception``.
    """

    def __init__(self, expected_type):
        self.exception = None
        self.expected_type = expected_type

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            _fail(f"Expected exception {str(self.expected_type.__name__)}, none raised")
        elif self.expected_type != exc_type:
            _fail(f"Expected exception {str(self.expected_type.__name__)}, raised {str(exc_type.__name__)}")
        self.exception = exc_val
        # Suppress the exception if it's the expected one
        return self.expected_type == exc_type


def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
    """Poll *cond* until it returns a truthy value.

    cond() is called immediately, then every *sleep* seconds. If it is still
    false once *deadline* seconds have passed (monotonic clock), a check
    failure is recorded and the function returns; it does not raise.
    """
    end = time.monotonic() + deadline
    while True:
        if cond():
            return
        if time.monotonic() > end:
            _fail("Waiting for condition timed out", comment)
            return
        time.sleep(sleep)


def ktap_result(ok, cnt=1, case="", comment=""):
    """Print one KTAP result line and fold *ok* into KSFT_RESULT_ALL.

    The line has the form ``[not ]ok <cnt> <KSFT_MAIN_NAME>[.<case name>]``,
    optionally followed by `` # <comment>``.
    """
    global KSFT_RESULT_ALL
    KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok

    res = ""
    if not ok:
        res += "not "
    res += "ok "
    res += str(cnt) + " "
    res += KSFT_MAIN_NAME
    if case:
        res += "." + str(case.__name__)
    if comment:
        res += " # " + comment
    print(res, flush=True)


def ksft_flush_defer():
    """Drain global_defer_queue, calling exec_only() on every entry.

    Entries are taken with .pop() (last queued first, assuming the queue is
    a list). Any exception, including KeyboardInterrupt, is logged with its
    traceback and marks the current test as failed instead of aborting the
    remaining cleanups.
    """
    global KSFT_RESULT

    i = 0
    qlen_start = len(global_defer_queue)
    while global_defer_queue:
        i += 1
        entry = global_defer_queue.pop()
        try:
            entry.exec_only()
        except BaseException:
            ksft_pr(f"Exception while handling defer / cleanup (callback {i} of {qlen_start})!")
            tb = traceback.format_exc()
            for line in tb.strip().split('\n'):
                ksft_pr("Defer Exception|", line)
            KSFT_RESULT = False


def ksft_disruptive(func):
    """
    Decorator that marks the test as disruptive (e.g. a test that can bring
    the interface down). Disruptive tests can be skipped by passing the
    DISRUPTIVE=False environment variable (see ksft_setup()).

    The check happens when the test is called, not when it is decorated.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not KSFT_DISRUPTIVE:
            raise KsftSkipEx("marked as disruptive")
        return func(*args, **kwargs)
    return wrapper


def ksft_setup(env):
    """
    Setup test framework global state from the environment.

    Currently only DISRUPTIVE is read; when present it sets KSFT_DISRUPTIVE.
    Returns *env* unchanged.
    """

    def get_bool(env, name):
        # Accepts yes/true/no/false (case-insensitive) or an integer
        # (non-zero is True); anything else, including "", is an error.
        value = env.get(name, "").lower()
        if value in ["yes", "true"]:
            return True
        if value in ["no", "false"]:
            return False
        try:
            return bool(int(value))
        except ValueError as e:
            raise Exception(f"failed to parse {name}") from e

    if "DISRUPTIVE" in env:
        global KSFT_DISRUPTIVE
        KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE")

    return env


def _ksft_intr(signum, frame):
    # ksft runner.sh sends 2 SIGTERMs in a row on a timeout
    # if we don't ignore the second one it will stop us from handling cleanup
    global term_cnt
    term_cnt += 1
    if term_cnt == 1:
        raise KsftTerminate()
    else:
        ksft_pr(f"Ignoring SIGTERM (cnt: {term_cnt}), already exiting...")


def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
    """Run test cases and report the results in KTAP format.

    Args:
        cases: iterable of test case callables, run in order. It is copied,
            never modified.
        globs: optional mapping (e.g. a module's globals()) scanned for
            further callables whose name starts with one of *case_pfx*;
            matches are appended after *cases*.
        case_pfx: one prefix string, or an iterable of prefix strings.
        args: positional arguments passed to every test case.

    KsftSkipEx / KsftXfailEx are reported as SKIP / XFAIL. Any other
    exception fails the case, and a KeyboardInterrupt (including
    KsftTerminate from SIGTERM) also stops the remaining cases. Deferred
    cleanups are flushed after every case. The previous SIGTERM handler is
    restored even if an exception escapes.
    """
    global KSFT_RESULT, term_cnt

    # Private copy: discovered cases are appended below and must not leak
    # into the caller's list; also allows tuples/generators.
    cases = list(cases or [])

    if globs and case_pfx:
        # A bare string would otherwise be iterated character by character.
        if isinstance(case_pfx, str):
            case_pfx = (case_pfx,)
        prefixes = tuple(case_pfx)
        for key, value in globs.items():
            if callable(value) and key.startswith(prefixes):
                cases.append(value)

    term_cnt = 0
    prev_sigterm = signal.signal(signal.SIGTERM, _ksft_intr)

    totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}

    try:
        print("TAP version 13", flush=True)
        print("1.." + str(len(cases)), flush=True)

        stop = False
        for cnt, case in enumerate(cases, start=1):
            KSFT_RESULT = True
            comment = ""
            cnt_key = ""

            try:
                case(*args)
            except KsftSkipEx as e:
                comment = "SKIP " + str(e)
                cnt_key = 'skip'
            except KsftXfailEx as e:
                comment = "XFAIL " + str(e)
                cnt_key = 'xfail'
            except BaseException as e:
                stop |= isinstance(e, KeyboardInterrupt)
                tb = traceback.format_exc()
                for line in tb.strip().split('\n'):
                    ksft_pr("Exception|", line)
                if stop:
                    ksft_pr(f"Stopping tests due to {type(e).__name__}.")
                KSFT_RESULT = False
                cnt_key = 'fail'

            ksft_flush_defer()

            if not cnt_key:
                cnt_key = 'pass' if KSFT_RESULT else 'fail'

            ktap_result(KSFT_RESULT, cnt, case, comment=comment)
            totals[cnt_key] += 1

            if stop:
                break
    finally:
        signal.signal(signal.SIGTERM, prev_sigterm)

    print(
        f"# Totals: pass:{totals['pass']} fail:{totals['fail']} "
        f"xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0",
        flush=True,
    )


def ksft_exit():
    """Exit with status 0 if every reported result passed, 1 otherwise."""
    sys.exit(0 if KSFT_RESULT_ALL else 1)