# SPDX-License-Identifier: GPL-2.0

"""
Helpers for writing test cases that report results in TAP/KTAP form.

Checks (ksft_eq(), ksft_true(), ...) do not raise; they print a diagnostic
and mark the currently running test case as failed via KSFT_RESULT.
ksft_run() executes the test cases and prints one result line per case,
ksft_exit() turns the accumulated result into the process exit code.
"""

import functools
import inspect
import signal
import sys
import time
import traceback
from .consts import KSFT_MAIN_NAME
from .utils import global_defer_queue

# Result of the test case currently being run (reset by ksft_run() per case)
KSFT_RESULT = None
# Logical AND of all results reported through ktap_result()
KSFT_RESULT_ALL = True
# Whether tests decorated with @ksft_disruptive are allowed to run
KSFT_DISRUPTIVE = True
# Number of SIGTERMs seen by _ksft_intr(); reset at the start of ksft_run()
term_cnt = 0


class KsftFailEx(Exception):
    """
    Exception signalling a test failure.

    Not handled specially in this file: ksft_run() reports any uncaught
    exception other than skip / xfail as a failed test case.
    """


class KsftSkipEx(Exception):
    """Raised to make ksft_run() report the test case with a SKIP comment."""


class KsftXfailEx(Exception):
    """Raised to make ksft_run() report the test case with an XFAIL comment."""


class KsftTerminate(KeyboardInterrupt):
    """
    Raised from the SIGTERM handler installed by ksft_run().

    Derives from KeyboardInterrupt so that ksft_run() stops running
    further test cases after it.
    """


def ksft_pr(*objs, **kwargs):
    """Print objs as a '#'-prefixed comment line, always flushing."""
    kwargs["flush"] = True
    print("#", *objs, **kwargs)


def _fail(*args):
    """
    Mark the current test case as failed and print where the check failed.

    Prints the call chain between ksft_run() and the caller of the check
    helper, followed by args. Must be called directly by the check helper:
    the two innermost stack frames (this function and its caller) are
    skipped.
    """
    global KSFT_RESULT
    KSFT_RESULT = False

    stack = inspect.stack()
    started = False
    for frame in reversed(stack[2:]):
        # Start printing from the test case function
        if not started:
            if frame.function == 'ksft_run':
                started = True
            continue

        ksft_pr("Check| At " + frame.filename + ", line " + str(frame.lineno) +
                ", in " + frame.function + ":")
        # code_context is None when the source of the frame is not available
        if frame.code_context:
            ksft_pr("Check| " + frame.code_context[0].strip())
    ksft_pr(*args)


def ksft_eq(a, b, comment=""):
    """Fail the current test case unless a == b."""
    if a != b:
        _fail("Check failed", a, "!=", b, comment)


def ksft_ne(a, b, comment=""):
    """Fail the current test case unless a != b."""
    if a == b:
        _fail("Check failed", a, "==", b, comment)


def ksft_true(a, comment=""):
    """Fail the current test case unless a is truthy."""
    if not a:
        _fail("Check failed", a, "does not eval to True", comment)


def ksft_not_none(a, comment=""):
    """Fail the current test case if a is None."""
    if a is None:
        _fail("Check failed", a, "is None", comment)


def ksft_in(a, b, comment=""):
    """Fail the current test case unless a is contained in b."""
    if a not in b:
        _fail("Check failed", a, "not in", b, comment)


def ksft_not_in(a, b, comment=""):
    """Fail the current test case if a is contained in b."""
    if a in b:
        _fail("Check failed", a, "in", b, comment)


def ksft_is(a, b, comment=""):
    """Fail the current test case unless a and b are the same object."""
    if a is not b:
        _fail("Check failed", a, "is not", b, comment)


def ksft_ge(a, b, comment=""):
    """Fail the current test case if a < b."""
    if a < b:
        _fail("Check failed", a, "<", b, comment)


def ksft_gt(a, b, comment=""):
    """Fail the current test case if a <= b."""
    if a <= b:
        _fail("Check failed", a, "<=", b, comment)


def ksft_lt(a, b, comment=""):
    """Fail the current test case if a >= b."""
    if a >= b:
        _fail("Check failed", a, ">=", b, comment)


class ksft_raises:
    """
    Context manager checking that its body raises expected_type.

    The exception type must match exactly (subclasses do not count).
    A matching exception is suppressed; any other exception fails the
    current test case and keeps propagating. The exception value seen on
    exit (None if nothing was raised) is stored in the exception attribute.
    """

    def __init__(self, expected_type):
        self.exception = None
        self.expected_type = expected_type

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            _fail(f"Expected exception {self.expected_type.__name__}, none raised")
        elif self.expected_type != exc_type:
            _fail(f"Expected exception {self.expected_type.__name__}, raised {exc_type.__name__}")
        self.exception = exc_val
        # Suppress the exception if its the expected one
        return self.expected_type == exc_type


def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
    """
    Poll cond() until it returns a truthy value.

    cond is called at least once. If more than deadline seconds pass
    without success the current test case is failed. sleep is the pause
    in seconds between two polls.
    """
    end = time.monotonic() + deadline
    while True:
        if cond():
            return
        if time.monotonic() > end:
            _fail("Waiting for condition timed out", comment)
            return
        time.sleep(sleep)


def ktap_result(ok, cnt=1, case_name="", comment=""):
    """
    Print one test result line and fold ok into KSFT_RESULT_ALL.

    The line has the form
    "[not ]ok <cnt> <KSFT_MAIN_NAME>[.<case_name>][ # <comment>]".
    """
    global KSFT_RESULT_ALL
    KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok

    res = ""
    if not ok:
        res += "not "
    res += "ok "
    res += str(cnt) + " "
    res += KSFT_MAIN_NAME
    if case_name:
        res += "." + case_name
    if comment:
        res += " # " + comment
    print(res, flush=True)


def ksft_flush_defer():
    """
    Run and remove every entry of global_defer_queue.

    Entries are taken with pop() and executed with exec_only(). Anything
    raised by an entry is printed and fails the current test case, the
    remaining entries are still executed.
    """
    global KSFT_RESULT

    i = 0
    qlen_start = len(global_defer_queue)
    while global_defer_queue:
        i += 1
        entry = global_defer_queue.pop()
        try:
            entry.exec_only()
        except BaseException:
            ksft_pr(f"Exception while handling defer / cleanup (callback {i} of {qlen_start})!")
            tb = traceback.format_exc()
            for line in tb.strip().split('\n'):
                ksft_pr("Defer Exception|", line)
            KSFT_RESULT = False


def ksft_disruptive(func):
    """
    Decorator that marks the test as disruptive (e.g. the test
    that can down the interface). Disruptive tests can be skipped
    by passing DISRUPTIVE=False environment variable.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not KSFT_DISRUPTIVE:
            raise KsftSkipEx("marked as disruptive")
        return func(*args, **kwargs)
    return wrapper


def ksft_setup(env):
    """
    Setup test framework global state from the environment.

    Reads the optional "DISRUPTIVE" key of env ("yes"/"true",
    "no"/"false" in any case, or an integer) into KSFT_DISRUPTIVE and
    returns env unchanged. Raises Exception if the value cannot be parsed.
    """

    def get_bool(env, name):
        value = env.get(name, "").lower()
        if value in ["yes", "true"]:
            return True
        if value in ["no", "false"]:
            return False
        try:
            return bool(int(value))
        except ValueError as exc:
            # Keep the parse error as the cause to ease debugging
            raise Exception(f"failed to parse {name}") from exc

    if "DISRUPTIVE" in env:
        global KSFT_DISRUPTIVE
        KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE")

    return env


def _ksft_intr(signum, frame):
    """SIGTERM handler: raise KsftTerminate on the first signal only."""
    # ksft runner.sh sends 2 SIGTERMs in a row on a timeout
    # if we don't ignore the second one it will stop us from handling cleanup
    global term_cnt
    term_cnt += 1
    if term_cnt == 1:
        raise KsftTerminate()
    else:
        ksft_pr(f"Ignoring SIGTERM (cnt: {term_cnt}), already exiting...")


def _ksft_generate_test_cases(cases, globs, case_pfx, args):
    """Generate a flat list of (func, args, name) tuples"""

    # Work on a copy, the list passed in by the caller must not be modified
    cases = list(cases) if cases else []

    # If using the globs method find all relevant functions
    if globs and case_pfx:
        for key, value in globs.items():
            if not callable(value):
                continue
            for prefix in case_pfx:
                if key.startswith(prefix):
                    cases.append(value)
                    break

    return [(func, args, func.__name__) for func in cases]


def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
    """
    Run test cases and print the results.

    Test cases are the callables listed in cases plus, when both globs
    and case_pfx are given, every callable in the globs mapping whose name
    starts with one of the prefixes in case_pfx. Each case is called with
    args. After each case the deferred cleanups are flushed and one
    result line is printed. A KeyboardInterrupt (including the
    KsftTerminate raised on SIGTERM) fails the running case and stops the
    run. A totals line is printed at the end.
    """
    global KSFT_RESULT, term_cnt

    test_cases = _ksft_generate_test_cases(cases, globs, case_pfx, args)

    term_cnt = 0
    prev_sigterm = signal.signal(signal.SIGTERM, _ksft_intr)

    totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}

    try:
        print("TAP version 13", flush=True)
        print("1.." + str(len(test_cases)), flush=True)

        stop = False
        for cnt, (func, case_args, name) in enumerate(test_cases, start=1):
            KSFT_RESULT = True
            comment = ""
            cnt_key = ""

            try:
                func(*case_args)
            except KsftSkipEx as e:
                comment = "SKIP " + str(e)
                cnt_key = 'skip'
            except KsftXfailEx as e:
                comment = "XFAIL " + str(e)
                cnt_key = 'xfail'
            except BaseException as e:
                stop |= isinstance(e, KeyboardInterrupt)
                tb = traceback.format_exc()
                for line in tb.strip().split('\n'):
                    ksft_pr("Exception|", line)
                if stop:
                    ksft_pr(f"Stopping tests due to {type(e).__name__}.")
                KSFT_RESULT = False
                cnt_key = 'fail'

            ksft_flush_defer()

            if not cnt_key:
                cnt_key = 'pass' if KSFT_RESULT else 'fail'

            ktap_result(KSFT_RESULT, cnt, name, comment=comment)
            totals[cnt_key] += 1

            if stop:
                break
    finally:
        # Restore the previous handler even if something escapes the loop
        signal.signal(signal.SIGTERM, prev_sigterm)

    print(
        f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0"
    )


def ksft_exit():
    """Exit the process with status 0 if all reported results were ok, else 1."""
    sys.exit(0 if KSFT_RESULT_ALL else 1)