# SPDX-License-Identifier: GPL-2.0

import builtins
import functools
import inspect
import sys
import time
import traceback
from .consts import KSFT_MAIN_NAME
from .utils import global_defer_queue

# Result of the test case currently being run by ksft_run(). ksft_run() resets
# it to True before each case; failed checks, exceptions and failing deferred
# cleanups set it to False.
KSFT_RESULT = None
# Logical AND of every result reported through ktap_result(); ksft_exit()
# derives the process exit status from it.
KSFT_RESULT_ALL = True
# Whether tests decorated with @ksft_disruptive may run; see ksft_setup().
KSFT_DISRUPTIVE = True


class KsftFailEx(Exception):
    """Test failure exception; ksft_run() reports it like any other exception (a failure)."""
    pass


class KsftSkipEx(Exception):
    """Raised by a test case to have ksft_run() report it with a SKIP comment."""
    pass


class KsftXfailEx(Exception):
    """Raised by a test case to have ksft_run() report it with an XFAIL comment."""
    pass


def ksft_pr(*objs, **kwargs):
    """Print *objs* as a TAP diagnostic line (prefixed with '#')."""
    print("#", *objs, **kwargs)


def _fail(*args):
    """Mark the current test case as failed and report where the check was.

    Prints the call stack from the frame just below ksft_run() (the test
    case) down to the function that invoked the failing ksft_* helper,
    followed by *args*. If no ksft_run() frame is on the stack, only *args*
    is printed.
    """
    global KSFT_RESULT
    KSFT_RESULT = False

    stack = inspect.stack()
    started = False
    # stack[0] is _fail() itself and stack[1] is the ksft_* helper that
    # called it; walk the remaining frames outermost-first.
    for frame in reversed(stack[2:]):
        # Start printing from the test case function
        if not started:
            if frame.function == 'ksft_run':
                started = True
            continue

        ksft_pr(f"Check| At {frame.filename}, line {frame.lineno}, in {frame.function}:")
        # code_context is None when inspect cannot retrieve the source
        # (e.g. code built with exec()); don't turn a failed check into a crash.
        if frame.code_context:
            ksft_pr("Check| " + frame.code_context[0].strip())
    ksft_pr(*args)


def ksft_eq(a, b, comment=""):
    """Fail the current test case unless ``a == b``."""
    if a != b:
        _fail("Check failed", a, "!=", b, comment)


def ksft_ne(a, b, comment=""):
    """Fail the current test case unless ``a != b``."""
    if a == b:
        _fail("Check failed", a, "==", b, comment)


def ksft_true(a, comment=""):
    """Fail the current test case unless *a* is truthy."""
    if not a:
        _fail("Check failed", a, "does not eval to True", comment)


def ksft_in(a, b, comment=""):
    """Fail the current test case unless ``a in b``."""
    if a not in b:
        _fail("Check failed", a, "not in", b, comment)


def ksft_not_in(a, b, comment=""):
    """Fail the current test case unless ``a not in b``."""
    if a in b:
        _fail("Check failed", a, "in", b, comment)


def ksft_is(a, b, comment=""):
    """Fail the current test case unless *a* is the same object as *b*."""
    if a is not b:
        _fail("Check failed", a, "is not", b, comment)


def ksft_ge(a, b, comment=""):
    """Fail the current test case if ``a < b``."""
    if a < b:
        _fail("Check failed", a, "<", b, comment)


def ksft_lt(a, b, comment=""):
    """Fail the current test case if ``a >= b``."""
    if a >= b:
        _fail("Check failed", a, ">=", b, comment)


class ksft_raises:
    """Context manager checking that its body raises *expected_type*.

    The exception type must match exactly (a subclass does not count). On a
    match the exception is suppressed. Otherwise the check fails, and any
    exception that was raised propagates. The exception raised in the body
    (or None) is stored in ``self.exception``.
    """

    def __init__(self, expected_type):
        self.exception = None
        self.expected_type = expected_type

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            _fail(f"Expected exception {str(self.expected_type.__name__)}, none raised")
        elif self.expected_type != exc_type:
            _fail(f"Expected exception {str(self.expected_type.__name__)}, raised {str(exc_type.__name__)}")
        self.exception = exc_val
        # Suppress the exception if it's the expected one
        return self.expected_type == exc_type


def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
    """Poll ``cond()`` until it returns a truthy value.

    :param cond: zero-argument callable being waited on.
    :param sleep: seconds to sleep between polls.
    :param deadline: seconds after which the wait gives up and the current
        test case is marked as failed (with *comment* in the report).
    """
    end = time.monotonic() + deadline
    while True:
        if cond():
            return
        if time.monotonic() > end:
            _fail("Waiting for condition timed out", comment)
            return
        time.sleep(sleep)


def ktap_result(ok, cnt=1, case="", comment=""):
    """Print one TAP result line and fold *ok* into KSFT_RESULT_ALL.

    Line format: ``[not ]ok <cnt> <KSFT_MAIN_NAME>[.<case.__name__>][ # <comment>]``.
    """
    global KSFT_RESULT_ALL
    KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok

    res = ""
    if not ok:
        res += "not "
    res += "ok "
    res += str(cnt) + " "
    res += KSFT_MAIN_NAME
    if case:
        res += "." + str(case.__name__)
    if comment:
        res += " # " + comment
    print(res)


def ksft_flush_defer():
    """Run every queued deferred cleanup until global_defer_queue is empty.

    Entries are popped from the end of the queue and run via exec_only().
    A cleanup that raises is reported with its traceback and marks the
    current test case as failed, but the remaining cleanups still run.
    """
    global KSFT_RESULT

    i = 0
    qlen_start = len(global_defer_queue)
    while global_defer_queue:
        i += 1
        entry = global_defer_queue.pop()
        try:
            entry.exec_only()
        except BaseException:
            # Deliberately catch everything (including KeyboardInterrupt) so
            # that one broken cleanup cannot prevent the rest from running.
            ksft_pr(f"Exception while handling defer / cleanup (callback {i} of {qlen_start})!")
            tb = traceback.format_exc()
            for line in tb.strip().split('\n'):
                ksft_pr("Defer Exception|", line)
            KSFT_RESULT = False


def ksft_disruptive(func):
    """
    Decorator that marks the test as disruptive (e.g. the test
    that can down the interface). Disruptive tests can be skipped
    by passing DISRUPTIVE=False environment variable.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # KSFT_DISRUPTIVE is read at call time, so ksft_setup() may run
        # after the decorator has been applied.
        if not KSFT_DISRUPTIVE:
            raise KsftSkipEx("marked as disruptive")
        return func(*args, **kwargs)
    return wrapper


def ksft_setup(env):
    """
    Set up test framework global state from the environment.

    Only DISRUPTIVE is consumed: it accepts "yes"/"true"/"no"/"false"
    (case-insensitive) or an integer (non-zero means True); anything else
    raises Exception. *env* is returned unchanged.
    """

    def get_bool(env, name):
        value = env.get(name, "").lower()
        if value in ["yes", "true"]:
            return True
        if value in ["no", "false"]:
            return False
        try:
            return bool(int(value))
        except ValueError as e:
            raise Exception(f"failed to parse {name}") from e

    if "DISRUPTIVE" in env:
        global KSFT_DISRUPTIVE
        KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE")

    return env


def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
    """Run test cases and print their results in TAP format.

    :param cases: iterable of test callables. It is copied, never modified.
    :param globs: optional mapping (e.g. ``globals()``) scanned for more
        cases. Each callable whose key starts with one of the *case_pfx*
        prefixes is appended after *cases*.
    :param case_pfx: iterable of name prefixes, used together with *globs*.
    :param args: positional arguments passed to every case.

    Each case fails if it raises, if a check fails, or if one of its deferred
    cleanups raises. KsftSkipEx / KsftXfailEx are reported with SKIP / XFAIL
    comments. A KeyboardInterrupt stops the run after the current case.
    """
    global KSFT_RESULT

    # Copy so the caller's list is never extended with discovered cases
    # (this also lets callers pass a tuple).
    cases = list(cases) if cases else []

    if globs and case_pfx:
        for key, value in globs.items():
            if callable(value) and any(key.startswith(pfx) for pfx in case_pfx):
                cases.append(value)

    totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}

    print("TAP version 13")
    print("1.." + str(len(cases)))

    stop = False
    for cnt, case in enumerate(cases, start=1):
        KSFT_RESULT = True
        comment = ""
        cnt_key = ""

        try:
            case(*args)
        except KsftSkipEx as e:
            comment = "SKIP " + str(e)
            cnt_key = 'skip'
        except KsftXfailEx as e:
            comment = "XFAIL " + str(e)
            cnt_key = 'xfail'
        except BaseException as e:
            stop |= isinstance(e, KeyboardInterrupt)
            tb = traceback.format_exc()
            for line in tb.strip().split('\n'):
                ksft_pr("Exception|", line)
            if stop:
                ksft_pr("Stopping tests due to KeyboardInterrupt.")
            KSFT_RESULT = False
            cnt_key = 'fail'

        # Cleanups always run, even after an exception; a failing cleanup
        # can still flip KSFT_RESULT to False.
        ksft_flush_defer()

        if not cnt_key:
            cnt_key = 'pass' if KSFT_RESULT else 'fail'

        ktap_result(KSFT_RESULT, cnt, case, comment=comment)
        totals[cnt_key] += 1

        if stop:
            break

    print(
        f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0"
    )


def ksft_exit():
    """Exit the process: status 0 if every reported result passed, else 1."""
    sys.exit(0 if KSFT_RESULT_ALL else 1)