# SPDX-License-Identifier: GPL-2.0

import builtins
import functools
import inspect
import sys
import time
import traceback
from .consts import KSFT_MAIN_NAME
from .utils import global_defer_queue

# Result of the test case currently being run: reset to True by ksft_run()
# before each case and cleared by any failing check (see _fail()).
KSFT_RESULT = None
# Aggregate of every result reported via ktap_result(); decides the exit
# status in ksft_exit().
KSFT_RESULT_ALL = True
# Whether tests decorated with @ksft_disruptive may run; set by ksft_setup().
KSFT_DISRUPTIVE = True


class KsftFailEx(Exception):
    """Raised to fail a test case.

    ksft_run() has no special handling for it: it is reported like any other
    unexpected exception (traceback printed, case marked as failed).
    """
    pass


class KsftSkipEx(Exception):
    """Raised by a test case to have it reported as skipped (``# SKIP``)."""
    pass


class KsftXfailEx(Exception):
    """Raised by a test case to have it reported as an expected failure
    (``# XFAIL``)."""
    pass


def ksft_pr(*objs, **kwargs):
    """print() with a leading "#", so the line is a KTAP diagnostic."""
    print("#", *objs, **kwargs)


def _fail(*args):
    """Mark the current test case as failed and report where it happened.

    Prints every frame from the test case function (the frame right below
    ksft_run()) down to the code that called the ksft_* check helper,
    followed by *args*.
    """
    global KSFT_RESULT
    KSFT_RESULT = False

    stack = inspect.stack()
    started = False
    # stack[0] is this function and stack[1] the ksft_* helper that called
    # it; neither is interesting to the reader.
    for frame in reversed(stack[2:]):
        # Start printing from the test case function
        if not started:
            if frame.function == 'ksft_run':
                started = True
            continue

        ksft_pr(f"Check| At {frame.filename}, line {frame.lineno}, in {frame.function}:")
        # code_context is None when the source cannot be retrieved (e.g. code
        # created via exec()); don't crash while reporting a failure.
        if frame.code_context:
            ksft_pr("Check| " + frame.code_context[0].strip())
    ksft_pr(*args)


def ksft_eq(a, b, comment=""):
    """Check that ``a == b``; record a failure (no exception) otherwise."""
    if a != b:
        _fail("Check failed", a, "!=", b, comment)


def ksft_ne(a, b, comment=""):
    """Check that ``a != b``; record a failure (no exception) otherwise."""
    if a == b:
        _fail("Check failed", a, "==", b, comment)


def ksft_true(a, comment=""):
    """Check that *a* is truthy; record a failure (no exception) otherwise."""
    if not a:
        _fail("Check failed", a, "does not eval to True", comment)


def ksft_in(a, b, comment=""):
    """Check that ``a in b``; record a failure (no exception) otherwise."""
    if a not in b:
        _fail("Check failed", a, "not in", b, comment)


def ksft_is(a, b, comment=""):
    """Check that ``a is b``; record a failure (no exception) otherwise."""
    if a is not b:
        _fail("Check failed", a, "is not", b, comment)


def ksft_ge(a, b, comment=""):
    """Check that ``a >= b``; record a failure (no exception) otherwise."""
    if a < b:
        _fail("Check failed", a, "<", b, comment)


def ksft_lt(a, b, comment=""):
    """Check that ``a < b``; record a failure (no exception) otherwise."""
    if a >= b:
        _fail("Check failed", a, ">=", b, comment)


class ksft_raises:
    """Context manager checking that its body raises *expected_type*.

    Matching is by exact type: a subclass of *expected_type* is reported as a
    failure and is not suppressed. After the block, ``self.exception`` holds
    the raised exception (None if nothing was raised)::

        with ksft_raises(OSError) as cm:
            ...
        ksft_eq(cm.exception.errno, ...)
    """

    def __init__(self, expected_type):
        self.exception = None
        self.expected_type = expected_type

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            _fail(f"Expected exception {self.expected_type.__name__}, none raised")
        elif self.expected_type != exc_type:
            _fail(f"Expected exception {self.expected_type.__name__}, raised {exc_type.__name__}")
        self.exception = exc_val
        # Suppress the exception only if it's the expected one; anything
        # else propagates to the caller.
        return self.expected_type == exc_type


def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
    """Poll ``cond()`` every *sleep* seconds until it returns a truthy value.

    If *deadline* seconds pass first, a failure is recorded (no exception is
    raised) and the function returns.
    """
    end = time.monotonic() + deadline
    while True:
        if cond():
            return
        if time.monotonic() > end:
            _fail("Waiting for condition timed out", comment)
            return
        time.sleep(sleep)


def ktap_result(ok, cnt=1, case="", comment=""):
    """Print one KTAP result line and fold *ok* into KSFT_RESULT_ALL.

    Format: ``[not ]ok <cnt> <KSFT_MAIN_NAME>[.<case name>][ # <comment>]``.
    """
    global KSFT_RESULT_ALL
    KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok

    res = ""
    if not ok:
        res += "not "
    res += "ok "
    res += str(cnt) + " "
    res += KSFT_MAIN_NAME
    if case:
        # Callables such as functools.partial have no __name__; fall back to
        # their str() rather than crashing the whole run.
        res += "." + str(getattr(case, "__name__", case))
    if comment:
        res += " # " + comment
    print(res)


def ksft_flush_defer():
    """Run and drain every queued deferred cleanup.

    Entries are popped from the end of global_defer_queue and run via
    ``exec_only()``. A failing entry does not stop the flush: its traceback
    is printed and the current test case is marked as failed.
    """
    global KSFT_RESULT

    i = 0
    qlen_start = len(global_defer_queue)
    while global_defer_queue:
        i += 1
        entry = global_defer_queue.pop()
        # Deliberately catch everything (including KeyboardInterrupt) so the
        # remaining cleanups still get a chance to run.
        try:
            entry.exec_only()
        except BaseException:
            ksft_pr(f"Exception while handling defer / cleanup (callback {i} of {qlen_start})!")
            tb = traceback.format_exc()
            for line in tb.strip().split('\n'):
                ksft_pr("Defer Exception|", line)
            KSFT_RESULT = False


def ksft_disruptive(func):
    """
    Decorator that marks a test as disruptive (e.g. a test that may bring
    the interface down). Disruptive tests raise KsftSkipEx instead of running
    when DISRUPTIVE is set to a false value in the environment passed to
    ksft_setup().
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not KSFT_DISRUPTIVE:
            raise KsftSkipEx("marked as disruptive")
        return func(*args, **kwargs)
    return wrapper


def ksft_setup(env):
    """
    Set up test framework global state from the environment.

    Currently only DISRUPTIVE is consumed. Accepted values are "yes"/"true",
    "no"/"false" (case-insensitive) or an integer (non-zero means True);
    anything else raises Exception. Returns *env* unchanged.
    """

    def get_bool(env, name):
        value = env.get(name, "").lower()
        if value in ["yes", "true"]:
            return True
        if value in ["no", "false"]:
            return False
        try:
            return bool(int(value))
        except ValueError as err:
            raise Exception(f"failed to parse {name}") from err

    if "DISRUPTIVE" in env:
        global KSFT_DISRUPTIVE
        KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE")

    return env


def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
    """Run test cases and print the results in KTAP format.

    :param cases: test callables to run first, in order. The caller's
        sequence is copied, not modified.
    :param globs: namespace (e.g. ``globals()``) to scan for more cases.
    :param case_pfx: a name prefix, or a collection of prefixes. Every
        callable in *globs* whose name starts with one of them is appended
        to the cases.
    :param args: positional arguments passed to every case.

    Each case runs with KSFT_RESULT reset to True, and deferred cleanups are
    flushed after it. An exception other than KsftSkipEx / KsftXfailEx fails
    the case. A KeyboardInterrupt additionally stops the run once that
    case's cleanups have been flushed.
    """
    # Copy so that appending discovered cases never mutates the caller's list.
    cases = list(cases) if cases else []

    if globs and case_pfx:
        # A bare string would otherwise be iterated character by character.
        if isinstance(case_pfx, str):
            case_pfx = (case_pfx,)
        prefixes = tuple(case_pfx)
        for key, value in globs.items():
            if callable(value) and key.startswith(prefixes):
                cases.append(value)

    totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}

    print("KTAP version 1")
    print("1.." + str(len(cases)))

    global KSFT_RESULT
    cnt = 0
    stop = False
    for case in cases:
        KSFT_RESULT = True
        cnt += 1
        comment = ""
        cnt_key = ""

        try:
            case(*args)
        except KsftSkipEx as e:
            comment = "SKIP " + str(e)
            cnt_key = 'skip'
        except KsftXfailEx as e:
            comment = "XFAIL " + str(e)
            cnt_key = 'xfail'
        except BaseException as e:
            stop |= isinstance(e, KeyboardInterrupt)
            tb = traceback.format_exc()
            for line in tb.strip().split('\n'):
                ksft_pr("Exception|", line)
            if stop:
                ksft_pr("Stopping tests due to KeyboardInterrupt.")
            KSFT_RESULT = False
            cnt_key = 'fail'

        # Cleanups may themselves fail the case, so flush before deciding.
        ksft_flush_defer()

        if not cnt_key:
            cnt_key = 'pass' if KSFT_RESULT else 'fail'

        ktap_result(KSFT_RESULT, cnt, case, comment=comment)
        totals[cnt_key] += 1

        if stop:
            break

    print(
        f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0"
    )


def ksft_exit():
    """Exit with status 0 if every reported result passed, 1 otherwise."""
    sys.exit(0 if KSFT_RESULT_ALL else 1)