# SPDX-License-Identifier: GPL-2.0

import fnmatch
import functools
import getopt
import inspect
import os
import signal
import sys
import time
import traceback
from collections import namedtuple
from .consts import KSFT_MAIN_NAME
from . import utils

# Result of the test case currently being run; None until ksft_run()
# starts executing cases.
KSFT_RESULT = None
# Logical AND of every result reported via ktap_result(); used for the
# process exit code.
KSFT_RESULT_ALL = True
# Whether tests marked with @ksft_disruptive are allowed to run.
KSFT_DISRUPTIVE = True


class KsftFailEx(Exception):
    """Raise to fail the current test case (ksft_run() reports any
    unexpected exception as a failure)."""
    pass


class KsftSkipEx(Exception):
    """Raise to report the current test case as skipped (SKIP)."""
    pass


class KsftXfailEx(Exception):
    """Raise to report the current test case as an expected failure (XFAIL)."""
    pass


class KsftTerminate(KeyboardInterrupt):
    """Raised from the SIGTERM handler; as a KeyboardInterrupt it makes
    ksft_run() stop after the current test case."""
    pass


class _KsftArgs:
    """
    Command line options of a ksft test script.

    Attributes:
      list_tests - True if -l was given (list test names and exit)
      filters    - list of (include, pattern) tuples, in command line order,
                   built from -t (include) and -T (exclude)

    Prints usage and exits on -h; prints the error and exits with 1 on an
    unknown option.
    """

    def __init__(self):
        self.list_tests = False
        self.filters = []

        try:
            opts, _ = getopt.getopt(sys.argv[1:], 'hlt:T:')
        except getopt.GetoptError as e:
            print(e, file=sys.stderr)
            sys.exit(1)

        for opt, val in opts:
            if opt == '-h':
                print(f"Usage: {sys.argv[0]} [-h|-l] [-t|-T name]\n"
                      f"\t-h print help\n"
                      f"\t-l list tests (filtered, if filters were specified)\n"
                      f"\t-t name include test\n"
                      f"\t-T name exclude test",
                      file=sys.stderr)
                sys.exit(0)
            elif opt == '-l':
                self.list_tests = True
            elif opt == '-t':
                self.filters.append((True, val))
            elif opt == '-T':
                self.filters.append((False, val))


@functools.lru_cache()
def _ksft_supports_color():
    """
    Return True if ANSI colors should be used for result lines.

    Colors are disabled when NO_COLOR is set, stdout is not a TTY, or
    TERM is "dumb". The answer is cached for the life of the process.
    """
    if os.environ.get("NO_COLOR") is not None:
        return False
    if not hasattr(sys.stdout, "isatty") or not sys.stdout.isatty():
        return False
    if os.environ.get("TERM") == "dumb":
        return False
    return True


def ksft_pr(*objs, **kwargs):
    """
    Print logs to stdout.

    Behaves like print() but log lines will be prefixed
    with # to prevent breaking the TAP output formatting.

    Extra arguments (on top of what print() supports):
      line_pfx - add extra string before each line
    """
    sep = kwargs.pop("sep", " ")
    pfx = kwargs.pop("line_pfx", "")
    pfx = "#" + (" " + pfx if pfx else "")
    kwargs["flush"] = True

    text = sep.join(str(obj) for obj in objs)
    # Re-prefix every line of a multi-line message, not just the first
    prefixed = f"\n{pfx} ".join(text.split('\n'))
    print(pfx, prefixed, **kwargs)


def _fail(*args):
    """
    Mark the current test case as failed and log where it happened.

    Prints the call chain from the test case function (the frame right
    after ksft_run()) down to the failing check, then prints args.
    """
    global KSFT_RESULT
    KSFT_RESULT = False

    stack = inspect.stack()
    started = False
    # stack[0] is this function and stack[1] the check helper that called
    # it; walk the remaining frames from the outermost one inwards.
    for frame in reversed(stack[2:]):
        # Start printing from the test case function
        if not started:
            if frame.function == 'ksft_run':
                started = True
            continue

        ksft_pr("Check| At " + frame.filename + ", line " + str(frame.lineno) +
                ", in " + frame.function + ":")
        # inspect leaves code_context as None when the source file cannot
        # be read; don't let the failure reporter crash on that.
        if frame.code_context:
            ksft_pr("Check| " + frame.code_context[0].strip())
    ksft_pr(*args)


def ksft_eq(a, b, comment=""):
    """Fail the current test case unless a == b."""
    if a != b:
        _fail("Check failed", a, "!=", b, comment)


def ksft_ne(a, b, comment=""):
    """Fail the current test case if a == b."""
    if a == b:
        _fail("Check failed", a, "==", b, comment)


def ksft_true(a, comment=""):
    """Fail the current test case unless a is truthy."""
    if not a:
        _fail("Check failed", a, "does not eval to True", comment)


def ksft_not_none(a, comment=""):
    """Fail the current test case if a is None."""
    if a is None:
        _fail("Check failed", a, "is None", comment)


def ksft_in(a, b, comment=""):
    """Fail the current test case unless a is in b."""
    if a not in b:
        _fail("Check failed", a, "not in", b, comment)


def ksft_not_in(a, b, comment=""):
    """Fail the current test case if a is in b."""
    if a in b:
        _fail("Check failed", a, "in", b, comment)


def ksft_is(a, b, comment=""):
    """Fail the current test case unless a is b (identity)."""
    if a is not b:
        _fail("Check failed", a, "is not", b, comment)


def ksft_ge(a, b, comment=""):
    """Fail the current test case unless a >= b."""
    if a < b:
        _fail("Check failed", a, "<", b, comment)


def ksft_gt(a, b, comment=""):
    """Fail the current test case unless a > b."""
    if a <= b:
        _fail("Check failed", a, "<=", b, comment)


def ksft_lt(a, b, comment=""):
    """Fail the current test case unless a < b."""
    if a >= b:
        _fail("Check failed", a, ">=", b, comment)
class ksft_raises:
    """
    Context manager checking that the body raises expected_type.

    The exception type must match exactly. A matching exception is
    suppressed; any other exception is recorded as a check failure and
    propagates. The raised exception (if any) is stored in .exception.
    """

    def __init__(self, expected_type):
        self.exception = None
        self.expected_type = expected_type

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            _fail(f"Expected exception {str(self.expected_type.__name__)}, none raised")
        elif self.expected_type != exc_type:
            _fail(f"Expected exception {str(self.expected_type.__name__)}, raised {str(exc_type.__name__)}")
        self.exception = exc_val
        # Suppress the exception if its the expected one
        return self.expected_type == exc_type


def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
    """
    Poll cond() every `sleep` seconds until it returns truthy.

    Records a check failure (without raising) if cond() is still falsy
    after `deadline` seconds.
    """
    end = time.monotonic() + deadline
    while True:
        if cond():
            return
        if time.monotonic() > end:
            _fail("Waiting for condition timed out", comment)
            return
        time.sleep(sleep)


def ktap_result(ok, cnt=1, case_name="", comment=""):
    """
    Print one KTAP result line and fold `ok` into KSFT_RESULT_ALL.

    The line has the form "[not ]ok <cnt> <KSFT_MAIN_NAME>[.<case_name>]
    [# <comment>]", colored yellow for SKIP/XFAIL, green for ok and red
    otherwise when the terminal supports colors.
    """
    global KSFT_RESULT_ALL
    KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok

    res = ""
    if not ok:
        res += "not "
    res += "ok "
    res += str(cnt) + " "
    res += KSFT_MAIN_NAME
    if case_name:
        res += "." + case_name
    if comment:
        res += " # " + comment
    if _ksft_supports_color():
        if comment.startswith(("SKIP", "XFAIL")):
            color = "\033[33m"
        elif ok:
            color = "\033[32m"
        else:
            color = "\033[31m"
        res = color + res + "\033[0m"
    print(res, flush=True)


def _ksft_defer_arm(state):
    """ Allow or disallow the use of defer() """
    utils.GLOBAL_DEFER_ARMED = state


def ksft_flush_defer():
    """
    Run and drain all queued defer() callbacks.

    Entries are taken with pop() from the end of utils.GLOBAL_DEFER_QUEUE
    (presumably LIFO w.r.t. registration - confirm against utils). An
    Exception from a callback is logged, marks the current case as
    failed, and does not stop the remaining callbacks; BaseExceptions such
    as KeyboardInterrupt propagate.
    """
    global KSFT_RESULT

    i = 0
    qlen_start = len(utils.GLOBAL_DEFER_QUEUE)
    while utils.GLOBAL_DEFER_QUEUE:
        i += 1
        entry = utils.GLOBAL_DEFER_QUEUE.pop()
        try:
            entry.exec_only()
        except Exception:
            ksft_pr(f"Exception while handling defer / cleanup (callback {i} of {qlen_start})!")
            ksft_pr(traceback.format_exc(), line_pfx="Defer Exception|")
            KSFT_RESULT = False


# A parametrized test case, as produced by @ksft_variants
KsftCaseFunction = namedtuple("KsftCaseFunction",
                              ['name', 'original_func', 'variants'])


def ksft_disruptive(func):
    """
    Decorator that marks the test as disruptive (e.g. the test
    that can down the interface). Disruptive tests can be skipped
    by passing DISRUPTIVE=False environment variable.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not KSFT_DISRUPTIVE:
            raise KsftSkipEx("marked as disruptive")
        return func(*args, **kwargs)
    return wrapper


class KsftNamedVariant:
    """ Named string name + argument list tuple for @ksft_variants """

    def __init__(self, name, *params):
        self.params = params
        # Without an explicit name, derive one from the parameters
        self.name = name or "_".join([str(x) for x in self.params])


def ksft_variants(params):
    """
    Decorator defining the sets of inputs for a test.
    The parameters will be included in the name of the resulting sub-case.
    Parameters can be either single object, tuple or a KsftNamedVariant.
    The argument can be a list or a generator.

    Example:

    @ksft_variants([
        (1, "a"),
        (2, "b"),
        KsftNamedVariant("three", 3, "c"),
    ])
    def my_case(cfg, a, b):
        pass # ...

    ksft_run(cases=[my_case], args=(cfg, ))

    Will generate cases:
        my_case.1_a
        my_case.2_b
        my_case.three
    """

    return lambda func: KsftCaseFunction(func.__name__, func, params)


def ksft_setup(env):
    """
    Setup test framework global state from the environment.

    Only DISRUPTIVE is read: "yes"/"true"/"no"/"false" (any case) or an
    integer. Returns env unchanged. Raises Exception if DISRUPTIVE is set
    but cannot be parsed.
    """
    global KSFT_DISRUPTIVE

    def get_bool(env, name):
        value = env.get(name, "").lower()
        if value in ["yes", "true"]:
            return True
        if value in ["no", "false"]:
            return False
        try:
            return bool(int(value))
        except Exception as e:
            raise Exception(f"failed to parse {name}") from e

    if "DISRUPTIVE" in env:
        KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE")

    return env


def _ksft_intr(signum, frame):
    """SIGTERM handler installed by ksft_run(); raises KsftTerminate once."""
    # ksft runner.sh sends 2 SIGTERMs in a row on a timeout
    # if we don't ignore the second one it will stop us from handling cleanup
    global term_cnt
    term_cnt += 1
    if term_cnt == 1:
        raise KsftTerminate()
    else:
        ksft_pr(f"Ignoring SIGTERM (cnt: {term_cnt}), already exiting...")


def _ksft_name_matches(name, pattern):
    """Match name against pattern: a glob if it has wildcards, else exact."""
    if '*' in pattern or '?' in pattern or '[' in pattern:
        return fnmatch.fnmatchcase(name, pattern)
    return name == pattern


def _ksft_test_enabled(name, filters):
    """
    Decide whether test `name` should run given (include, pattern) filters.

    The first matching filter wins. If none match, the test runs only when
    no include (-t) filter was given.
    """
    has_positive = False
    for include, pattern in filters:
        has_positive |= include
        if _ksft_name_matches(name, pattern):
            return include
    return not has_positive


def _ksft_generate_test_cases(cases, globs, case_pfx, args, cli_args):
    """Generate a filtered list of (func, args, name) tuples.

    cases    - explicit test cases: plain callables and/or the
               KsftCaseFunction objects made by @ksft_variants. Any
               iterable is accepted; the caller's object is not modified.
    globs    - optional dict (e.g. globals()) scanned for further cases
               whose keys start with one of the prefixes in case_pfx
    case_pfx - iterable of name prefixes, or a single prefix string
    args     - arguments passed to every case; variant params are appended
    cli_args - parsed command line (filters, list_tests)

    If -l is given, prints matching test names and exits.
    """

    # Work on a copy: discovered cases must not leak into the caller's list
    cases = list(cases or [])
    test_cases = []

    # If using the globs method find all relevant functions
    if globs and case_pfx:
        # A lone string is one prefix, not a sequence of 1-char prefixes
        if isinstance(case_pfx, str):
            case_pfx = (case_pfx, )
        prefixes = tuple(case_pfx)
        for key, value in globs.items():
            # @ksft_variants produces a KsftCaseFunction, which is not
            # callable but is a valid test case
            if not callable(value) and not isinstance(value, KsftCaseFunction):
                continue
            if key.startswith(prefixes):
                cases.append(value)

    for func in cases:
        if isinstance(func, KsftCaseFunction):
            # Parametrized test - create case for each param
            for param in func.variants:
                if not isinstance(param, KsftNamedVariant):
                    if not isinstance(param, tuple):
                        param = (param, )
                    param = KsftNamedVariant(None, *param)

                test_cases.append((func.original_func,
                                   (*args, *param.params),
                                   func.name + "." + param.name))
        else:
            test_cases.append((func, args, func.__name__))

    if cli_args.filters:
        test_cases = [tc for tc in test_cases
                      if _ksft_test_enabled(tc[2], cli_args.filters)]

    if cli_args.list_tests:
        for _, _, name in test_cases:
            print(name)
        sys.exit(0)

    return test_cases


def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
    """
    Run the selected test cases, printing KTAP output and totals.

    cases    - explicit test cases (plain or @ksft_variants-decorated)
    globs    - dict of globals scanned for cases (together with case_pfx)
    case_pfx - prefixes selecting test cases by name from globs
    args     - positional arguments passed to every test case

    Stops after a test case (or its defer cleanup) is interrupted by a
    KeyboardInterrupt / SIGTERM. The previous SIGTERM handler is restored
    before returning or propagating an exception.

    Raises RuntimeError if called more than once.
    """
    global KSFT_RESULT
    global term_cnt

    cli_args = _KsftArgs()
    test_cases = _ksft_generate_test_cases(cases, globs, case_pfx, args,
                                           cli_args)

    # Check before touching signal handlers, so misuse leaves no trace
    if KSFT_RESULT is not None:
        raise RuntimeError("ksft_run() can't be called multiple times.")

    totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}

    term_cnt = 0
    prev_sigterm = signal.signal(signal.SIGTERM, _ksft_intr)
    try:
        print("TAP version 13", flush=True)
        print("1.." + str(len(test_cases)), flush=True)

        stop = False
        for cnt, (func, case_args, name) in enumerate(test_cases, start=1):
            KSFT_RESULT = True
            comment = ""
            cnt_key = ""

            _ksft_defer_arm(True)
            try:
                func(*case_args)
            except KsftSkipEx as e:
                comment = "SKIP " + str(e)
                cnt_key = 'skip'
            except KsftXfailEx as e:
                comment = "XFAIL " + str(e)
                cnt_key = 'xfail'
            except BaseException as e:
                stop |= isinstance(e, KeyboardInterrupt)
                ksft_pr(traceback.format_exc(), line_pfx="Exception|")
                if stop:
                    ksft_pr(f"Stopping tests due to {type(e).__name__}.")
                KSFT_RESULT = False
                cnt_key = 'fail'
            _ksft_defer_arm(False)

            try:
                ksft_flush_defer()
            except BaseException as e:
                ksft_pr(traceback.format_exc(), line_pfx="Exception|")
                if isinstance(e, KeyboardInterrupt):
                    ksft_pr()
                    ksft_pr("WARN: defer() interrupted, cleanup may be incomplete.")
                    ksft_pr(" Attempting to finish cleanup before exiting.")
                    ksft_pr(" Interrupt again to exit immediately.")
                    ksft_pr()
                    stop = True
                # Flush was interrupted, try to finish the job best we can
                ksft_flush_defer()

            if not cnt_key:
                cnt_key = 'pass' if KSFT_RESULT else 'fail'

            ktap_result(KSFT_RESULT, cnt, name, comment=comment)
            totals[cnt_key] += 1

            if stop:
                break
    finally:
        signal.signal(signal.SIGTERM, prev_sigterm)

    print(
        f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0"
    )


def ksft_exit():
    """Exit the process: 0 if every reported result was ok, 1 otherwise."""
    global KSFT_RESULT_ALL
    sys.exit(0 if KSFT_RESULT_ALL else 1)