# SPDX-License-Identifier: GPL-2.0

import functools
import inspect
import signal
import sys
import time
import traceback
from collections import namedtuple
from .consts import KSFT_MAIN_NAME
from .utils import global_defer_queue

# Result of the test case currently being run. ksft_run() sets it to True
# before each case; failed checks and unexpected exceptions set it to False.
KSFT_RESULT = None
# Logical AND of every result reported via ktap_result(); read by ksft_exit().
KSFT_RESULT_ALL = True
# Whether tests wrapped with @ksft_disruptive may run; see ksft_setup().
KSFT_DISRUPTIVE = True
# Number of SIGTERMs seen by _ksft_intr(); reset by ksft_run().
term_cnt = 0


class KsftFailEx(Exception):
    """
    Not handled specially by ksft_run(); like any other unexpected
    exception it makes the running case fail.
    """


class KsftSkipEx(Exception):
    """ Raised by a test case to have ksft_run() report it as SKIP """


class KsftXfailEx(Exception):
    """ Raised by a test case to have ksft_run() report it as XFAIL """


class KsftTerminate(KeyboardInterrupt):
    """
    Raised by the SIGTERM handler. Being a KeyboardInterrupt it makes
    ksft_run() stop after the current test case.
    """


def ksft_pr(*objs, **kwargs):
    """
    Print a "#"-prefixed line. Accepts the same arguments as print(),
    except that flush is always forced to True.
    """
    kwargs["flush"] = True
    print("#", *objs, **kwargs)


def _fail(*args):
    """
    Mark the current test case as failed and print where the check was made.

    The call stack is printed outermost-first, beginning with the frame
    directly below ksft_run() (i.e. the test case function). If ksft_run()
    is not on the stack no frames are printed. @args are printed last,
    via ksft_pr().
    """
    global KSFT_RESULT
    KSFT_RESULT = False

    # stack[0] is _fail() itself and stack[1] is the check helper which
    # called it, neither is of interest to the person reading the log.
    stack = inspect.stack()
    started = False
    for frame in reversed(stack[2:]):
        # Start printing from the test case function
        if not started:
            if frame.function == 'ksft_run':
                started = True
            continue

        ksft_pr("Check| At " + frame.filename + ", line " + str(frame.lineno) +
                ", in " + frame.function + ":")
        # code_context is None when the source line can't be retrieved,
        # don't let that turn a failed check into a crash.
        context = frame.code_context
        if context:
            ksft_pr("Check| " + context[0].strip())
    ksft_pr(*args)


def ksft_eq(a, b, comment=""):
    """ Fail the current test case if a != b """
    if a != b:
        _fail("Check failed", a, "!=", b, comment)


def ksft_ne(a, b, comment=""):
    """ Fail the current test case if a == b """
    if a == b:
        _fail("Check failed", a, "==", b, comment)


def ksft_true(a, comment=""):
    """ Fail the current test case if @a is falsy """
    if not a:
        _fail("Check failed", a, "does not eval to True", comment)


def ksft_not_none(a, comment=""):
    """ Fail the current test case if @a is None """
    if a is None:
        _fail("Check failed", a, "is None", comment)


def ksft_in(a, b, comment=""):
    """ Fail the current test case if @a is not in @b """
    if a not in b:
        _fail("Check failed", a, "not in", b, comment)


def ksft_not_in(a, b, comment=""):
    """ Fail the current test case if @a is in @b """
    if a in b:
        _fail("Check failed", a, "in", b, comment)


def ksft_is(a, b, comment=""):
    """ Fail the current test case if @a is not the same object as @b """
    if a is not b:
        _fail("Check failed", a, "is not", b, comment)


def ksft_ge(a, b, comment=""):
    """ Fail the current test case if a < b """
    if a < b:
        _fail("Check failed", a, "<", b, comment)


def ksft_gt(a, b, comment=""):
    """ Fail the current test case if a <= b """
    if a <= b:
        _fail("Check failed", a, "<=", b, comment)


def ksft_lt(a, b, comment=""):
    """ Fail the current test case if a >= b """
    if a >= b:
        _fail("Check failed", a, ">=", b, comment)


class ksft_raises:
    """
    Context manager checking that the body raises @expected_type.

    The type must match exactly, subclasses of @expected_type are not
    accepted. The expected exception is suppressed and made available as
    the .exception attribute. If nothing is raised the test case is failed.
    An exception of any other type fails the test case and propagates.
    """

    def __init__(self, expected_type):
        self.exception = None
        self.expected_type = expected_type

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        expected_name = self.expected_type.__name__
        if exc_type is None:
            _fail(f"Expected exception {expected_name}, none raised")
        elif self.expected_type != exc_type:
            _fail(f"Expected exception {expected_name}, raised {exc_type.__name__}")
        self.exception = exc_val
        # Suppress the exception if its the expected one
        return self.expected_type == exc_type


def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
    """
    Poll @cond() until it returns a truthy value.

    @cond is always called at least once. If it is still falsy once more
    than @deadline seconds have passed the test case is failed and the
    function returns. @sleep is the delay between polls, in seconds.
    """
    end = time.monotonic() + deadline
    while True:
        if cond():
            return
        if time.monotonic() > end:
            _fail("Waiting for condition timed out", comment)
            return
        time.sleep(sleep)


def ktap_result(ok, cnt=1, case_name="", comment=""):
    """
    Print one KTAP result line and fold @ok into KSFT_RESULT_ALL.

    The line has the form:
      [not ]ok <cnt> <KSFT_MAIN_NAME>[.<case_name>][ # <comment>]
    """
    global KSFT_RESULT_ALL
    KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok

    res = ""
    if not ok:
        res += "not "
    res += "ok "
    res += str(cnt) + " "
    res += KSFT_MAIN_NAME
    if case_name:
        res += "." + case_name
    if comment:
        res += " # " + comment
    print(res, flush=True)


def ksft_flush_defer():
    """
    Run and remove all entries of the global defer queue, last one first.

    An Exception raised by an entry is printed, fails the current test
    case, and does not prevent the remaining entries from running.
    """
    global KSFT_RESULT

    i = 0
    qlen_start = len(global_defer_queue)
    while global_defer_queue:
        i += 1
        entry = global_defer_queue.pop()
        try:
            entry.exec_only()
        except Exception:
            ksft_pr(f"Exception while handling defer / cleanup (callback {i} of {qlen_start})!")
            tb = traceback.format_exc()
            for line in tb.strip().split('\n'):
                ksft_pr("Defer Exception|", line)
            KSFT_RESULT = False


KsftCaseFunction = namedtuple("KsftCaseFunction",
                              ['name', 'original_func', 'variants'])


def ksft_disruptive(func):
    """
    Decorator that marks the test as disruptive (e.g. the test
    that can down the interface). Disruptive tests can be skipped
    by passing DISRUPTIVE=False environment variable.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not KSFT_DISRUPTIVE:
            raise KsftSkipEx("marked as disruptive")
        return func(*args, **kwargs)
    return wrapper


class KsftNamedVariant:
    """ Named string name + argument list tuple for @ksft_variants """

    def __init__(self, name, *params):
        self.params = params
        # Without an explicit name derive one from the parameters
        self.name = name or "_".join([str(x) for x in self.params])


def ksft_variants(params):
    """
    Decorator defining the sets of inputs for a test.
    The parameters will be included in the name of the resulting sub-case.
    Parameters can be either single object, tuple or a KsftNamedVariant.
    The argument can be a list or a generator.

    Note that the decorated name is bound to a KsftCaseFunction tuple,
    not to a callable function.

    Example:

    @ksft_variants([
        (1, "a"),
        (2, "b"),
        KsftNamedVariant("three", 3, "c"),
    ])
    def my_case(cfg, a, b):
        pass # ...

    ksft_run(cases=[my_case], args=(cfg, ))

    Will generate cases:
        my_case.1_a
        my_case.2_b
        my_case.three
    """

    return lambda func: KsftCaseFunction(func.__name__, func, params)


def ksft_setup(env):
    """
    Setup test framework global state from the environment.

    If @env contains DISRUPTIVE it is parsed as a boolean ("yes" / "true" /
    "no" / "false" in any letter case, or an integer) and stored in
    KSFT_DISRUPTIVE. An Exception is raised if the value can't be parsed.
    Returns @env.
    """

    def get_bool(env, name):
        value = env.get(name, "").lower()
        if value in ["yes", "true"]:
            return True
        if value in ["no", "false"]:
            return False
        try:
            return bool(int(value))
        except ValueError as e:
            raise Exception(f"failed to parse {name}") from e

    if "DISRUPTIVE" in env:
        global KSFT_DISRUPTIVE
        KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE")

    return env


def _ksft_intr(signum, frame):
    """ SIGTERM handler installed for the duration of ksft_run() """
    # ksft runner.sh sends 2 SIGTERMs in a row on a timeout
    # if we don't ignore the second one it will stop us from handling cleanup
    global term_cnt
    term_cnt += 1
    if term_cnt == 1:
        raise KsftTerminate()
    else:
        ksft_pr(f"Ignoring SIGTERM (cnt: {term_cnt}), already exiting...")


def _ksft_generate_test_cases(cases, globs, case_pfx, args):
    """Generate a flat list of (func, args, name) tuples"""

    # Work on a copy, the list passed in by the caller must not be modified
    cases = list(cases) if cases else []
    test_cases = []

    # If using the globs method find all relevant functions
    if globs and case_pfx:
        for key, value in globs.items():
            # Cases wrapped by @ksft_variants are KsftCaseFunction tuples,
            # those are not callable but are test cases nonetheless.
            if not callable(value) and not isinstance(value, KsftCaseFunction):
                continue
            if any(key.startswith(prefix) for prefix in case_pfx):
                cases.append(value)

    for func in cases:
        if isinstance(func, KsftCaseFunction):
            # Parametrized test - create case for each param
            for param in func.variants:
                if not isinstance(param, KsftNamedVariant):
                    if not isinstance(param, tuple):
                        param = (param, )
                    param = KsftNamedVariant(None, *param)

                test_cases.append((func.original_func,
                                   (*args, *param.params),
                                   func.name + "." + param.name))
        else:
            test_cases.append((func, args, func.__name__))

    return test_cases


def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
    """
    Run test cases and print the results in KTAP format.

    Test cases are taken from @cases and / or found in @globs (e.g. the
    globals() of the test) by matching names against the prefixes in
    @case_pfx. Each case is called with @args, followed by the variant
    parameters for cases defined with @ksft_variants.

    After every case the defer queue is flushed. A KeyboardInterrupt
    (which includes SIGTERM, see _ksft_intr()) stops the run once the
    interrupted case has been cleaned up and reported.
    """
    global KSFT_RESULT, term_cnt

    test_cases = _ksft_generate_test_cases(cases, globs, case_pfx, args)

    term_cnt = 0
    prev_sigterm = signal.signal(signal.SIGTERM, _ksft_intr)

    totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}

    print("TAP version 13", flush=True)
    print("1.." + str(len(test_cases)), flush=True)

    stop = False
    for cnt, (func, case_args, name) in enumerate(test_cases, start=1):
        KSFT_RESULT = True
        comment = ""
        cnt_key = ""

        try:
            func(*case_args)
        except KsftSkipEx as e:
            comment = "SKIP " + str(e)
            cnt_key = 'skip'
        except KsftXfailEx as e:
            comment = "XFAIL " + str(e)
            cnt_key = 'xfail'
        except BaseException as e:
            stop |= isinstance(e, KeyboardInterrupt)
            tb = traceback.format_exc()
            for line in tb.strip().split('\n'):
                ksft_pr("Exception|", line)
            if stop:
                ksft_pr(f"Stopping tests due to {type(e).__name__}.")
            KSFT_RESULT = False
            cnt_key = 'fail'

        try:
            ksft_flush_defer()
        except BaseException as e:
            tb = traceback.format_exc()
            for line in tb.strip().split('\n'):
                ksft_pr("Exception|", line)
            if isinstance(e, KeyboardInterrupt):
                ksft_pr()
                ksft_pr("WARN: defer() interrupted, cleanup may be incomplete.")
                ksft_pr(" Attempting to finish cleanup before exiting.")
                ksft_pr(" Interrupt again to exit immediately.")
                ksft_pr()
                stop = True
            # Flush was interrupted, try to finish the job best we can
            ksft_flush_defer()

        # No exception decided the outcome, go by the result of the checks
        if not cnt_key:
            cnt_key = 'pass' if KSFT_RESULT else 'fail'

        ktap_result(KSFT_RESULT, cnt, name, comment=comment)
        totals[cnt_key] += 1

        if stop:
            break

    signal.signal(signal.SIGTERM, prev_sigterm)

    print(
        f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0"
    )


def ksft_exit():
    """
    Exit the process, with status 0 if all results reported via
    ktap_result() were ok, 1 otherwise.
    """
    sys.exit(0 if KSFT_RESULT_ALL else 1)