# SPDX-License-Identifier: GPL-2.0
"""
Helpers for writing kselftest test cases in Python.

Provides check helpers (ksft_eq() and friends), KTAP result output,
flushing of deferred cleanups and the ksft_run() test runner.
"""

import functools
import inspect
import os
import signal
import sys
import time
import traceback
from collections import namedtuple
from .consts import KSFT_MAIN_NAME
from . import utils

# Result of the currently running test case (reset to True per case by ksft_run())
KSFT_RESULT = None
# AND of all results reported via ktap_result(), used by ksft_exit()
KSFT_RESULT_ALL = True
# Whether tests decorated with @ksft_disruptive may run (see ksft_setup())
KSFT_DISRUPTIVE = True


class KsftFailEx(Exception):
    """ Raise from a test case to make it fail. """


class KsftSkipEx(Exception):
    """ Raise from a test case to report it as skipped (SKIP). """


class KsftXfailEx(Exception):
    """ Raise from a test case to report an expected failure (XFAIL). """


class KsftTerminate(KeyboardInterrupt):
    """ Raised by the SIGTERM handler installed by ksft_run(). """


@functools.lru_cache()
def _ksft_supports_color():
    """
    Return True if result lines may be colored with ANSI escape codes.

    Color is disabled when NO_COLOR is set (to any value), when stdout
    is not a TTY, or when TERM is "dumb". The answer is cached.
    """
    if os.environ.get("NO_COLOR") is not None:
        return False
    if not hasattr(sys.stdout, "isatty") or not sys.stdout.isatty():
        return False
    if os.environ.get("TERM") == "dumb":
        return False
    return True


def ksft_pr(*objs, **kwargs):
    """
    Print logs to stdout.

    Behaves like print() but log lines will be prefixed
    with # to prevent breaking the TAP output formatting.

    Extra arguments (on top of what print() supports):
      line_pfx - add extra string before each line
    """
    sep = kwargs.pop("sep", " ")
    pfx = kwargs.pop("line_pfx", "")
    pfx = "#" + (" " + pfx if pfx else "")
    kwargs["flush"] = True

    text = sep.join(str(obj) for obj in objs)
    # Prefix every line, not just the first, so multi-line text stays a comment
    prefixed = f"\n{pfx} ".join(text.split('\n'))
    print(pfx, prefixed, **kwargs)


def _fail(*args):
    """
    Mark the current test case as failed and log where the check failed.

    Prints the call stack from the test case function (the frames below
    ksft_run()) down to the check helper's caller, then prints @args.
    Frames whose source line is unavailable are printed without it.
    """
    global KSFT_RESULT
    KSFT_RESULT = False

    stack = inspect.stack()
    started = False
    # Skip _fail() itself and the helper which called it; walk outermost first
    for frame in reversed(stack[2:]):
        # Start printing from the test case function
        if not started:
            if frame.function == 'ksft_run':
                started = True
            continue

        ksft_pr("Check| At " + frame.filename + ", line " + str(frame.lineno) +
                ", in " + frame.function + ":")
        # code_context is None when the source cannot be retrieved
        if frame.code_context:
            ksft_pr("Check| " + frame.code_context[0].strip())
    ksft_pr(*args)


def ksft_eq(a, b, comment=""):
    """ Fail the current test case unless a == b. """
    if a != b:
        _fail("Check failed", a, "!=", b, comment)


def ksft_ne(a, b, comment=""):
    """ Fail the current test case if a == b. """
    if a == b:
        _fail("Check failed", a, "==", b, comment)


def ksft_true(a, comment=""):
    """ Fail the current test case unless a is truthy. """
    if not a:
        _fail("Check failed", a, "does not eval to True", comment)


def ksft_not_none(a, comment=""):
    """ Fail the current test case if a is None. """
    if a is None:
        _fail("Check failed", a, "is None", comment)


def ksft_in(a, b, comment=""):
    """ Fail the current test case unless a in b. """
    if a not in b:
        _fail("Check failed", a, "not in", b, comment)


def ksft_not_in(a, b, comment=""):
    """ Fail the current test case if a in b. """
    if a in b:
        _fail("Check failed", a, "in", b, comment)


def ksft_is(a, b, comment=""):
    """ Fail the current test case unless a is b (identity). """
    if a is not b:
        _fail("Check failed", a, "is not", b, comment)


def ksft_ge(a, b, comment=""):
    """ Fail the current test case unless a >= b. """
    if a < b:
        _fail("Check failed", a, "<", b, comment)


def ksft_gt(a, b, comment=""):
    """ Fail the current test case unless a > b. """
    if a <= b:
        _fail("Check failed", a, "<=", b, comment)


def ksft_lt(a, b, comment=""):
    """ Fail the current test case unless a < b. """
    if a >= b:
        _fail("Check failed", a, ">=", b, comment)


class ksft_raises:
    """
    Context manager checking that its body raises expected_type.

    The exception type must match exactly (subclasses do not count).
    The raised exception (if any) is stored in .exception. The expected
    exception is suppressed. A missing or different exception fails the
    current test case, and a different exception keeps propagating.
    """

    def __init__(self, expected_type):
        self.exception = None
        self.expected_type = expected_type

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            _fail(f"Expected exception {str(self.expected_type.__name__)}, none raised")
        elif self.expected_type != exc_type:
            _fail(f"Expected exception {str(self.expected_type.__name__)}, raised {str(exc_type.__name__)}")
        self.exception = exc_val
        # Suppress the exception if it's the expected one
        return self.expected_type == exc_type


def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
    """
    Poll cond() every @sleep seconds until it returns a truthy value.

    Fails the current test case (and returns) if @deadline seconds pass
    without the condition becoming true.
    """
    end = time.monotonic() + deadline
    while True:
        if cond():
            return
        if time.monotonic() > end:
            _fail("Waiting for condition timed out", comment)
            return
        time.sleep(sleep)


def ktap_result(ok, cnt=1, case_name="", comment=""):
    """
    Print one KTAP result line and fold @ok into KSFT_RESULT_ALL.

    Line format: "[not ]ok <cnt> <KSFT_MAIN_NAME>[.<case_name>][ # <comment>]".
    When supported, the line is colored yellow for SKIP/XFAIL comments,
    green for passes and red for failures.
    """
    global KSFT_RESULT_ALL
    KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok

    res = ""
    if not ok:
        res += "not "
    res += "ok "
    res += str(cnt) + " "
    res += KSFT_MAIN_NAME
    if case_name:
        res += "." + case_name
    if comment:
        res += " # " + comment
    if _ksft_supports_color():
        if comment.startswith(("SKIP", "XFAIL")):
            color = "\033[33m"
        elif ok:
            color = "\033[32m"
        else:
            color = "\033[31m"
        res = color + res + "\033[0m"
    print(res, flush=True)


def _ksft_defer_arm(state):
    """ Allow or disallow the use of defer() """
    utils.GLOBAL_DEFER_ARMED = state


def ksft_flush_defer():
    """
    Run and drain the global defer queue.

    Entries are taken with pop(), i.e. from the end of the queue. An
    Exception raised by a callback is logged and fails the current test
    case, but the remaining callbacks still run. Other BaseExceptions
    (e.g. KeyboardInterrupt) propagate to the caller.
    """
    global KSFT_RESULT

    i = 0
    qlen_start = len(utils.GLOBAL_DEFER_QUEUE)
    while utils.GLOBAL_DEFER_QUEUE:
        i += 1
        entry = utils.GLOBAL_DEFER_QUEUE.pop()
        try:
            entry.exec_only()
        except Exception:
            ksft_pr(f"Exception while handling defer / cleanup (callback {i} of {qlen_start})!")
            ksft_pr(traceback.format_exc(), line_pfx="Defer Exception|")
            KSFT_RESULT = False


# Test case produced by @ksft_variants: the undecorated function plus its
# parameter sets, expanded into sub-cases by _ksft_generate_test_cases()
KsftCaseFunction = namedtuple("KsftCaseFunction",
                              ['name', 'original_func', 'variants'])


def ksft_disruptive(func):
    """
    Decorator that marks the test as disruptive (e.g. the test
    that can down the interface). Disruptive tests can be skipped
    by passing DISRUPTIVE=False environment variable.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not KSFT_DISRUPTIVE:
            raise KsftSkipEx("marked as disruptive")
        return func(*args, **kwargs)
    return wrapper


class KsftNamedVariant:
    """
    Named string name + argument list tuple for @ksft_variants.

    If name is empty/None the name is built by joining str() of the
    params with "_".
    """

    def __init__(self, name, *params):
        self.params = params
        self.name = name or "_".join([str(x) for x in self.params])


def ksft_variants(params):
    """
    Decorator defining the sets of inputs for a test.
    The parameters will be included in the name of the resulting sub-case.
    Parameters can be either single object, tuple or a KsftNamedVariant.
    The argument can be a list or a generator.

    Example:

        @ksft_variants([
            (1, "a"),
            (2, "b"),
            KsftNamedVariant("three", 3, "c"),
        ])
        def my_case(cfg, a, b):
            pass # ...

        ksft_run(cases=[my_case], args=(cfg, ))

    Will generate cases:
        my_case.1_a
        my_case.2_b
        my_case.three
    """

    return lambda func: KsftCaseFunction(func.__name__, func, params)


def ksft_setup(env):
    """
    Setup test framework global state from the environment.

    Currently only DISRUPTIVE is consumed. It accepts yes/true/no/false
    (case-insensitive) or an integer. Returns @env unchanged.

    Raises Exception if DISRUPTIVE is present but cannot be parsed.
    """
    global KSFT_DISRUPTIVE

    def get_bool(env, name):
        value = env.get(name, "").lower()
        if value in ["yes", "true"]:
            return True
        if value in ["no", "false"]:
            return False
        try:
            return bool(int(value))
        except ValueError as e:
            # int() of a str can only fail with ValueError
            raise Exception(f"failed to parse {name}") from e

    if "DISRUPTIVE" in env:
        KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE")

    return env


def _ksft_intr(signum, frame):
    """ SIGTERM handler installed by ksft_run(). """
    # ksft runner.sh sends 2 SIGTERMs in a row on a timeout
    # if we don't ignore the second one it will stop us from handling cleanup
    global term_cnt
    term_cnt += 1
    if term_cnt == 1:
        raise KsftTerminate()
    else:
        ksft_pr(f"Ignoring SIGTERM (cnt: {term_cnt}), already exiting...")


def _ksft_generate_test_cases(cases, globs, case_pfx, args):
    """
    Generate a flat list of (func, args, name) tuples.

    @cases holds plain callables and/or KsftCaseFunction objects (from
    @ksft_variants); each variant becomes its own entry, with the variant
    params appended after @args. When both @globs and @case_pfx are given,
    every callable in @globs whose name starts with one of the prefixes
    in @case_pfx (a single string or a collection of strings) is added.
    The caller's @cases list is not modified.
    """

    # Copy so that appending discovered cases doesn't mutate the caller's list
    cases = list(cases or [])
    test_cases = []

    # If using the globs method find all relevant functions
    if globs and case_pfx:
        # A bare string would otherwise be iterated character by character
        if isinstance(case_pfx, str):
            case_pfx = (case_pfx, )
        prefixes = tuple(case_pfx)
        for key, value in globs.items():
            if callable(value) and key.startswith(prefixes):
                cases.append(value)

    for func in cases:
        if isinstance(func, KsftCaseFunction):
            # Parametrized test - create case for each param
            for param in func.variants:
                if not isinstance(param, KsftNamedVariant):
                    if not isinstance(param, tuple):
                        param = (param, )
                    param = KsftNamedVariant(None, *param)

                test_cases.append((func.original_func,
                                   (*args, *param.params),
                                   func.name + "." + param.name))
        else:
            test_cases.append((func, args, func.__name__))

    return test_cases


def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
    """
    Run test cases and print their results in KTAP format.

    cases    - list of test callables and/or @ksft_variants cases
    globs    - dict searched for additional test functions (e.g. globals())
    case_pfx - name prefix, or collection of prefixes, selecting functions
               from globs
    args     - positional arguments passed to every test case

    A case passes unless a check fails or it raises. KsftSkipEx and
    KsftXfailEx are reported as SKIP / XFAIL. Deferred cleanups are
    flushed after every case. A KeyboardInterrupt (including KsftTerminate
    raised on SIGTERM) stops the run after the current case is reported.
    """
    global term_cnt
    global KSFT_RESULT

    test_cases = _ksft_generate_test_cases(cases, globs, case_pfx, args)

    term_cnt = 0
    prev_sigterm = signal.signal(signal.SIGTERM, _ksft_intr)

    totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}

    print("TAP version 13", flush=True)
    print("1.." + str(len(test_cases)), flush=True)

    cnt = 0
    stop = False
    try:
        for func, case_args, name in test_cases:
            KSFT_RESULT = True
            cnt += 1
            comment = ""
            cnt_key = ""

            _ksft_defer_arm(True)
            try:
                func(*case_args)
            except KsftSkipEx as e:
                comment = "SKIP " + str(e)
                cnt_key = 'skip'
            except KsftXfailEx as e:
                comment = "XFAIL " + str(e)
                cnt_key = 'xfail'
            except BaseException as e:
                stop |= isinstance(e, KeyboardInterrupt)
                ksft_pr(traceback.format_exc(), line_pfx="Exception|")
                if stop:
                    ksft_pr(f"Stopping tests due to {type(e).__name__}.")
                KSFT_RESULT = False
                cnt_key = 'fail'
            _ksft_defer_arm(False)

            try:
                ksft_flush_defer()
            except BaseException as e:
                ksft_pr(traceback.format_exc(), line_pfx="Exception|")
                if isinstance(e, KeyboardInterrupt):
                    ksft_pr()
                    ksft_pr("WARN: defer() interrupted, cleanup may be incomplete.")
                    ksft_pr(" Attempting to finish cleanup before exiting.")
                    ksft_pr(" Interrupt again to exit immediately.")
                    ksft_pr()
                    stop = True
                # Flush was interrupted, try to finish the job best we can
                ksft_flush_defer()

            if not cnt_key:
                cnt_key = 'pass' if KSFT_RESULT else 'fail'

            ktap_result(KSFT_RESULT, cnt, name, comment=comment)
            totals[cnt_key] += 1

            if stop:
                break
    finally:
        # Restore even if an interrupt escapes (e.g. a second Ctrl-C during
        # cleanup), otherwise _ksft_intr() would keep swallowing SIGTERMs
        signal.signal(signal.SIGTERM, prev_sigterm)

    print(
        f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0"
    )


def ksft_exit():
    """ Exit with status 0 if no reported result was "not ok", else 1. """
    sys.exit(0 if KSFT_RESULT_ALL else 1)