xref: /linux/tools/testing/selftests/net/lib/py/ksft.py (revision 81b89085319b2b20426d8c9468d8c508dcefdaaf)
1# SPDX-License-Identifier: GPL-2.0
2
3import builtins
4import functools
5import inspect
6import signal
7import sys
8import time
9import traceback
10from .consts import KSFT_MAIN_NAME
11from .utils import global_defer_queue
12
# Result of the test case currently being run: ksft_run() sets it to True
# before each case, failed checks and exceptions set it to False.
KSFT_RESULT = None
# Overall result, ktap_result() ANDs every reported result into it and
# ksft_exit() turns it into the process exit code.
KSFT_RESULT_ALL = True
# Whether disruptive test cases may run. ksft_setup() overrides it from the
# DISRUPTIVE setting, ksft_disruptive() skips marked cases when it is False.
KSFT_DISRUPTIVE = True
16
17
class KsftFailEx(Exception):
    """
    Exception signalling a test failure.

    ksft_run() has no dedicated handler for it, it is reported like any
    other exception: the traceback is printed and the case counts as failed.
    """
20
21
class KsftSkipEx(Exception):
    """
    Raised by a test case to have it reported as skipped.

    ksft_run() uses the exception message as the text of the SKIP comment.
    """
24
25
class KsftXfailEx(Exception):
    """
    Raised by a test case to have it reported as an expected failure.

    ksft_run() uses the exception message as the text of the XFAIL comment.
    """
28
29
class KsftTerminate(KeyboardInterrupt):
    """
    Raised from the SIGTERM handler (_ksft_intr) on the first SIGTERM.

    Derives from KeyboardInterrupt so that ksft_run() treats it like an
    interrupt and stops running further test cases.
    """
32
33
def ksft_pr(*objs, **kwargs):
    """
    Print a log message to stdout as a KTAP comment.

    Takes the same arguments as print(). The message is prefixed with "#"
    and the output is always flushed. If the formatted message contains
    newlines every line gets its own "# " prefix, so that multi-line
    objects can't put unprefixed lines into the TAP output.
    """
    kwargs["flush"] = True

    # Do the joining print() would do ourselves, so that the line breaks
    # inside the message can be found and prefixed.
    sep = kwargs.pop("sep", None)
    if sep is None:
        sep = " "
    text = sep.join(["#"] + [str(obj) for obj in objs])

    print(text.replace("\n", "\n# "), **kwargs)
37
38
def _fail(*args):
    """
    Record a failed check for the current test case and report it.

    Sets KSFT_RESULT to False, prints the call chain which led to the
    failed check (the frames below ksft_run(), excluding this function and
    its direct caller) and then prints *args* as the failure message.
    Nothing of the call chain is printed if ksft_run() is not on the stack.
    """
    global KSFT_RESULT
    KSFT_RESULT = False

    # stack[0] is this function, stack[1] is whoever called it (the check
    # helper), skip both and walk the rest outermost frame first.
    stack = inspect.stack()
    started = False
    for frame in reversed(stack[2:]):
        # Start printing from the test case function
        if not started:
            if frame.function == 'ksft_run':
                started = True
            continue

        ksft_pr("Check| At " + frame.filename + ", line " + str(frame.lineno) +
                ", in " + frame.function + ":")
        # code_context is None when the source of the frame is not available
        if frame.code_context:
            ksft_pr("Check|     " + frame.code_context[0].strip())
    ksft_pr(*args)
56
57
def ksft_eq(a, b, comment=""):
    """
    Check that a == b, record a failure of the test case otherwise.

    comment is appended to the failure message.
    """
    differs = a != b
    if differs:
        _fail("Check failed", a, "!=", b, comment)
62
63
def ksft_ne(a, b, comment=""):
    """
    Check that a != b, record a failure of the test case otherwise.

    comment is appended to the failure message.
    """
    same = a == b
    if same:
        _fail("Check failed", a, "==", b, comment)
68
69
def ksft_true(a, comment=""):
    """
    Check that a is truthy, record a failure of the test case otherwise.

    comment is appended to the failure message.
    """
    if a:
        return
    _fail("Check failed", a, "does not eval to True", comment)
73
74
def ksft_in(a, b, comment=""):
    """
    Check that a is contained in b, record a failure of the test case
    otherwise.

    comment is appended to the failure message.
    """
    if a in b:
        return
    _fail("Check failed", a, "not in", b, comment)
78
79
def ksft_not_in(a, b, comment=""):
    """
    Check that a is not contained in b, record a failure of the test case
    otherwise.

    comment is appended to the failure message.
    """
    found = a in b
    if found:
        _fail("Check failed", a, "in", b, comment)
83
84
def ksft_is(a, b, comment=""):
    """
    Check that a and b are the same object, record a failure of the test
    case otherwise.

    comment is appended to the failure message.
    """
    if a is b:
        return
    _fail("Check failed", a, "is not", b, comment)
88
89
def ksft_ge(a, b, comment=""):
    """
    Check that a is not less than b, record a failure of the test case
    otherwise.

    comment is appended to the failure message.
    """
    too_small = a < b
    if too_small:
        _fail("Check failed", a, "<", b, comment)
93
94
def ksft_gt(a, b, comment=""):
    """
    Check that a is strictly greater than b (fails when a <= b), record a
    failure of the test case otherwise.

    comment is appended to the failure message.
    """
    not_above = a <= b
    if not_above:
        _fail("Check failed", a, "<=", b, comment)
98
99
def ksft_lt(a, b, comment=""):
    """
    Check that a is strictly less than b (fails when a >= b), record a
    failure of the test case otherwise.

    comment is appended to the failure message.
    """
    not_below = a >= b
    if not_below:
        _fail("Check failed", a, ">=", b, comment)
103
104
class ksft_raises:
    """
    Context manager checking that its body raises expected_type.

    The type of the raised exception has to be exactly expected_type,
    subclasses do not match. A matching exception is suppressed. If nothing
    is raised, or the type differs, a failure of the test case is recorded
    and an exception of the wrong type keeps propagating.

    After the block, the exception attribute holds the exception instance
    which was raised (None if there was none).
    """

    def __init__(self, expected_type):
        self.expected_type = expected_type
        self.exception = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        wanted = self.expected_type
        if exc_type is None:
            _fail(f"Expected exception {wanted.__name__}, none raised")
        elif wanted != exc_type:
            _fail(f"Expected exception {wanted.__name__}, raised {exc_type.__name__}")

        # Keep the instance around so the test can inspect it
        self.exception = exc_val
        # Returning True swallows the exception, do it only for the expected one
        return wanted == exc_type
121
122
def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
    """
    Poll cond() until it returns a truthy value.

    cond     - callable taking no arguments
    sleep    - seconds to sleep between two polls
    deadline - seconds after which to give up
    comment  - appended to the failure message

    Records a failure of the test case if the deadline passes without
    cond() having succeeded. cond() is always called at least once.
    """
    give_up_at = time.monotonic() + deadline
    while not cond():
        if time.monotonic() > give_up_at:
            _fail("Waiting for condition timed out", comment)
            return
        time.sleep(sleep)
131        time.sleep(sleep)
132
133
def ktap_result(ok, cnt=1, case="", comment=""):
    """
    Print the KTAP result line of one test case.

    ok      - whether the test case passed
    cnt     - sequence number of the test case
    case    - the test case, either a callable (its __name__ is printed)
              or its name
    comment - text to put after " # " at the end of the line, if any

    The result is also folded into KSFT_RESULT_ALL, which becomes falsy
    once any result was not ok.
    """
    global KSFT_RESULT_ALL
    KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok

    res = ""
    if not ok:
        res += "not "
    res += "ok "
    res += str(cnt) + " "
    res += KSFT_MAIN_NAME
    if case:
        # Plain names and callables without a __name__ (functools.partial)
        # are printed as they are instead of raising AttributeError.
        res += "." + str(getattr(case, "__name__", case))
    if comment:
        res += " # " + comment
    print(res, flush=True)
149
150
def ksft_flush_defer():
    """
    Run all the entries of global_defer_queue, emptying the queue.

    Entries are taken off the queue with pop() one at a time and run via
    their exec_only() method. An entry which raises does not stop the
    others from running: the traceback is printed and the current test case
    is marked as failed. Every kind of exception is caught on purpose,
    interrupts included, so that the remaining cleanups still get to run.
    """
    global KSFT_RESULT

    total = len(global_defer_queue)
    idx = 0
    while global_defer_queue:
        idx += 1
        callback = global_defer_queue.pop()
        try:
            callback.exec_only()
        except BaseException:
            ksft_pr(f"Exception while handling defer / cleanup (callback {idx} of {total})!")
            for line in traceback.format_exc().strip().split('\n'):
                ksft_pr("Defer Exception|", line)
            KSFT_RESULT = False
166            KSFT_RESULT = False
167
168
def ksft_disruptive(func):
    """
    Decorator that marks the test as disruptive (e.g. the test
    that can down the interface). Disruptive tests can be skipped
    by passing DISRUPTIVE=False environment variable.

    The check is done when the test is called, not when it is decorated:
    if KSFT_DISRUPTIVE is false at that point KsftSkipEx is raised instead
    of calling the test.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if KSFT_DISRUPTIVE:
            return func(*args, **kwargs)
        raise KsftSkipEx("marked as disruptive")

    return wrapper
182
183
def ksft_setup(env):
    """
    Setup test framework global state from the environment.

    env is a dict-like object holding the settings as strings. Only the
    DISRUPTIVE setting is looked at, if present it is parsed as a boolean
    and stored in KSFT_DISRUPTIVE. Booleans can be spelled yes / true /
    no / false (in any case) or as an integer, with 0 meaning False.

    Returns env, unmodified.
    Raises Exception if the value of a setting can't be parsed.
    """
    global KSFT_DISRUPTIVE

    def get_bool(env, name):
        # int() ignores surrounding whitespace, do the same for the words
        value = env.get(name, "").strip().lower()
        if value in ["yes", "true"]:
            return True
        if value in ["no", "false"]:
            return False
        try:
            return bool(int(value))
        except ValueError as err:
            # Only the parsing error is expected here, anything else
            # (e.g. KeyboardInterrupt) must not be converted.
            raise Exception(f"failed to parse {name}") from err

    if "DISRUPTIVE" in env:
        KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE")

    return env
205
206
def _ksft_intr(signum, frame):
    """
    SIGTERM handler installed by ksft_run().

    Raises KsftTerminate on the first SIGTERM counted in term_cnt, all the
    following ones are only logged.
    """
    # ksft runner.sh sends 2 SIGTERMs in a row on a timeout
    # if we don't ignore the second one it will stop us from handling cleanup
    global term_cnt
    term_cnt += 1
    if term_cnt != 1:
        ksft_pr(f"Ignoring SIGTERM (cnt: {term_cnt}), already exiting...")
        return
    raise KsftTerminate()
216
217
def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
    """
    Run test cases and report the results in KTAP format.

    cases    - test case callables to run
    globs    - dict of names to objects (e.g. globals()); the callables
               whose name starts with one of the prefixes in case_pfx are
               run after the explicitly listed cases
    case_pfx - iterable of name prefixes selecting the cases from globs
    args     - positional arguments passed to every test case

    A case is reported as skipped / xfail if it raises KsftSkipEx /
    KsftXfailEx, and as failed if it raises anything else or if a check
    failed (KSFT_RESULT was cleared). The deferred cleanups are flushed
    after every case. A KeyboardInterrupt (which includes KsftTerminate)
    additionally stops the run once the interrupted case has been cleaned up
    and reported.

    SIGTERM is routed to _ksft_intr() while the cases run, the previous
    handler is put back before returning.
    """
    global KSFT_RESULT
    global term_cnt

    # Work on a copy, the list of the caller must not grow when the cases
    # found in globs get added.
    cases = list(cases) if cases else []

    if globs and case_pfx:
        for key, value in globs.items():
            if not callable(value):
                continue
            for prefix in case_pfx:
                if key.startswith(prefix):
                    cases.append(value)
                    break

    term_cnt = 0
    prev_sigterm = signal.signal(signal.SIGTERM, _ksft_intr)

    totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}

    try:
        print("TAP version 13", flush=True)
        print("1.." + str(len(cases)), flush=True)

        cnt = 0
        stop = False
        for case in cases:
            KSFT_RESULT = True
            cnt += 1
            comment = ""
            cnt_key = ""

            try:
                case(*args)
            except KsftSkipEx as e:
                comment = "SKIP " + str(e)
                cnt_key = 'skip'
            except KsftXfailEx as e:
                comment = "XFAIL " + str(e)
                cnt_key = 'xfail'
            except BaseException as e:
                # Interrupts are caught as well, to still run the cleanups
                # and report the result of the case before stopping.
                stop |= isinstance(e, KeyboardInterrupt)
                tb = traceback.format_exc()
                for line in tb.strip().split('\n'):
                    ksft_pr("Exception|", line)
                if stop:
                    ksft_pr(f"Stopping tests due to {type(e).__name__}.")
                KSFT_RESULT = False
                cnt_key = 'fail'

            # Cleanups may fail the case, too, run them before looking
            # at the result.
            ksft_flush_defer()

            if not cnt_key:
                cnt_key = 'pass' if KSFT_RESULT else 'fail'

            ktap_result(KSFT_RESULT, cnt, case, comment=comment)
            totals[cnt_key] += 1

            if stop:
                break
    finally:
        # Don't leave our handler behind if something unexpected escapes.
        # signal.signal() returns None for a handler which was not installed
        # from Python, that one can't be put back using signal.signal().
        if prev_sigterm is not None:
            signal.signal(signal.SIGTERM, prev_sigterm)

    print(
        f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0",
        flush=True
    )
282
283
def ksft_exit():
    """
    Exit the process: status 0 if KSFT_RESULT_ALL is truthy, 1 otherwise.
    """
    status = 1
    if KSFT_RESULT_ALL:
        status = 0
    sys.exit(status)
287