xref: /linux/tools/testing/selftests/net/lib/py/ksft.py (revision df9c299371054cb725eef730fd0f1d0fe2ed6bb0)
1# SPDX-License-Identifier: GPL-2.0
2
3import builtins
4import functools
5import inspect
6import signal
7import sys
8import time
9import traceback
10from .consts import KSFT_MAIN_NAME
11from .utils import global_defer_queue
12
# Result of the test case currently being run: ksft_run() resets it to True
# before every case; _fail(), an unexpected exception in the case, or a
# failing deferred cleanup flips it to False.
KSFT_RESULT = None
# Aggregate of every result reported so far: ktap_result() ANDs each result
# into it and ksft_exit() derives the process exit code from it.
KSFT_RESULT_ALL = True
# Whether tests wrapped with @ksft_disruptive may run; ksft_setup() overrides
# it from the "DISRUPTIVE" entry of the environment it is given.
KSFT_DISRUPTIVE = True
16
17
class KsftFailEx(Exception):
    """
    Exception type for test failures.

    ksft_run() has no dedicated handler for it, so like any other
    unexpected exception it is reported as a failed test case.
    """
20
21
class KsftSkipEx(Exception):
    """
    Raised by a test case to have ksft_run() report it as skipped.

    The exception message is used as the text of the SKIP comment.
    """
24
25
class KsftXfailEx(Exception):
    """
    Raised by a test case to have ksft_run() report an expected failure.

    The exception message is used as the text of the XFAIL comment.
    """
28
29
class KsftTerminate(KeyboardInterrupt):
    """
    Raised from the SIGTERM handler to abort the test run.

    Deriving from KeyboardInterrupt lets ksft_run() treat a SIGTERM the
    same way as an interactive interrupt and stop after the current case.
    """
32
33
def ksft_pr(*objs, **kwargs):
    """
    Print the given objects as a comment line, i.e. prefixed with "#".

    Keyword arguments are forwarded to print() unchanged.
    """
    prefixed = ("#",) + objs
    print(*prefixed, **kwargs)
36
37
def _fail(*args):
    """
    Mark the current test case as failed and print where the check failed.

    Sets KSFT_RESULT to False, then prints the call chain between
    ksft_run() and the failed check, followed by @args as the failure
    message. The two innermost frames (this function and the check helper
    which called it) are not printed. Nothing but the message is printed
    if ksft_run() is not on the stack.
    """
    global KSFT_RESULT
    KSFT_RESULT = False

    stack = inspect.stack()
    started = False
    for frame in reversed(stack[2:]):
        # Start printing from the test case function
        if not started:
            if frame.function == 'ksft_run':
                started = True
            continue

        ksft_pr("Check| At " + frame.filename + ", line " + str(frame.lineno) +
                ", in " + frame.function + ":")
        # code_context is None when the source of the frame can't be
        # retrieved (e.g. code run via exec()); don't let that hide the
        # actual failure message below.
        if frame.code_context:
            ksft_pr("Check|     " + frame.code_context[0].strip())
        else:
            ksft_pr("Check|     <source not available>")
    ksft_pr(*args)
55
56
def ksft_eq(a, b, comment=""):
    """Check that @a equals @b, record a test failure otherwise."""
    differs = a != b
    if differs:
        _fail("Check failed", a, "!=", b, comment)
61
62
def ksft_ne(a, b, comment=""):
    """Check that @a differs from @b, record a test failure otherwise."""
    same = a == b
    if same:
        _fail("Check failed", a, "==", b, comment)
67
68
def ksft_true(a, comment=""):
    """Check that @a is truthy, record a test failure otherwise."""
    if a:
        return
    _fail("Check failed", a, "does not eval to True", comment)
72
73
def ksft_in(a, b, comment=""):
    """Check that @a is contained in @b, record a test failure otherwise."""
    if a in b:
        return
    _fail("Check failed", a, "not in", b, comment)
77
78
def ksft_not_in(a, b, comment=""):
    """Check that @a is not contained in @b, record a test failure otherwise."""
    present = a in b
    if present:
        _fail("Check failed", a, "in", b, comment)
82
83
def ksft_is(a, b, comment=""):
    """Check that @a and @b are the same object, record a test failure otherwise."""
    if a is b:
        return
    _fail("Check failed", a, "is not", b, comment)
87
88
def ksft_ge(a, b, comment=""):
    """Check that @a is not less than @b, record a test failure otherwise."""
    too_small = a < b
    if too_small:
        _fail("Check failed", a, "<", b, comment)
92
93
def ksft_lt(a, b, comment=""):
    """Check that @a is less than @b, record a test failure otherwise."""
    too_big = a >= b
    if too_big:
        _fail("Check failed", a, ">=", b, comment)
97
98
class ksft_raises:
    """
    Context manager checking that its body raises @expected_type.

    The match is on the exact type, subclasses of @expected_type do not
    count. A matching exception is suppressed; a mismatching one is
    recorded as a test failure and left to propagate. A body which raises
    nothing is recorded as a test failure as well. The exception value
    (None if nothing was raised) is stored in the `exception` attribute.
    """

    def __init__(self, expected_type):
        self.expected_type = expected_type
        self.exception = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        wanted = self.expected_type
        if exc_type is None:
            _fail(f"Expected exception {wanted.__name__}, none raised")
        elif wanted != exc_type:
            _fail(f"Expected exception {wanted.__name__}, raised {exc_type.__name__}")
        self.exception = exc_val
        # A true return value swallows the exception, so only do that
        # when we got exactly the type we were waiting for
        return wanted == exc_type
115
116
def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
    """
    Poll @cond() until it returns a truthy value.

    @sleep is the pause between polls and @deadline the time budget, both
    in seconds. If the budget is exceeded a test failure is recorded
    (with @comment appended to the message) and the function returns.
    """
    give_up_at = time.monotonic() + deadline
    while not cond():
        if time.monotonic() > give_up_at:
            _fail("Waiting for condition timed out", comment)
            return
        time.sleep(sleep)
126
127
def ktap_result(ok, cnt=1, case="", comment=""):
    """
    Print one test result line and fold @ok into the overall result.

    The line has the form "[not ]ok <cnt> <main name>[.<case name>]" with
    " # <comment>" appended when @comment is set. @case, when given, is
    expected to have a __name__ attribute (e.g. the test function).
    """
    global KSFT_RESULT_ALL
    KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok

    pieces = ["ok " if ok else "not ok ", str(cnt), " ", KSFT_MAIN_NAME]
    if case:
        pieces.append("." + str(case.__name__))
    if comment:
        pieces.append(" # " + comment)
    print("".join(pieces))
143
144
def ksft_flush_defer():
    """
    Run and drop all callbacks queued in global_defer_queue.

    Entries are popped from the end of the queue one at a time and run
    via their exec_only() method. Anything raised by a callback is
    printed, marks the current test as failed, and does not prevent the
    remaining callbacks from running.
    """
    global KSFT_RESULT

    total = len(global_defer_queue)
    nth = 0
    while global_defer_queue:
        nth += 1
        callback = global_defer_queue.pop()
        try:
            callback.exec_only()
        except BaseException:
            # Cleanup is best effort, keep going whatever was raised
            ksft_pr(f"Exception while handling defer / cleanup (callback {nth} of {total})!")
            for line in traceback.format_exc().strip().split('\n'):
                ksft_pr("Defer Exception|", line)
            KSFT_RESULT = False
161
162
def ksft_disruptive(func):
    """
    Decorator marking a test as disruptive (e.g. a test which can take
    the interface down). When disruptive tests are disabled (see the
    DISRUPTIVE environment variable handling in ksft_setup()) the
    decorated test raises KsftSkipEx instead of running.
    """

    @functools.wraps(func)
    def guarded(*args, **kwargs):
        if KSFT_DISRUPTIVE:
            return func(*args, **kwargs)
        raise KsftSkipEx("marked as disruptive")

    return guarded
176
177
def ksft_setup(env):
    """
    Setup test framework global state from the environment.

    @env is a mapping of environment variable names to string values.
    If it contains "DISRUPTIVE", the value is parsed as a boolean
    ("yes" / "true" / "no" / "false" in any letter case, or an integer)
    and stored in KSFT_DISRUPTIVE.

    Returns @env unmodified. Raises Exception if the value can't be parsed.
    """
    global KSFT_DISRUPTIVE

    def get_bool(env, name):
        value = env.get(name, "").lower()
        if value in ["yes", "true"]:
            return True
        if value in ["no", "false"]:
            return False
        try:
            return bool(int(value))
        except ValueError as err:
            # Only a malformed number is a parse error; anything else
            # (e.g. KeyboardInterrupt) must not be converted.
            raise Exception(f"failed to parse {name}") from err

    if "DISRUPTIVE" in env:
        KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE")

    return env
199
200
def _ksft_intr(signum, frame):
    """
    SIGTERM handler: raise KsftTerminate on the first signal only.

    ksft runner.sh sends 2 SIGTERMs in a row on a timeout; raising again
    on the second one would interrupt the cleanup started by the first,
    so every signal after the first is only logged.
    """
    global term_cnt
    term_cnt += 1
    if term_cnt != 1:
        ksft_pr(f"Ignoring SIGTERM (cnt: {term_cnt}), already exiting...")
        return
    raise KsftTerminate()
210
211
def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
    """
    Run test cases and print the results as TAP.

    @cases is an optional iterable of test callables. If both @globs
    (a name -> object mapping, e.g. globals()) and @case_pfx (an iterable
    of name prefixes) are given, every callable in @globs whose name
    starts with one of the prefixes is run as well, after @cases.
    Each case is called with @args unpacked as positional arguments.

    A case raising KsftSkipEx is reported as SKIP, one raising KsftXfailEx
    as XFAIL; any other exception is printed and counts as a failure.
    KeyboardInterrupt (including the KsftTerminate raised on SIGTERM)
    additionally stops the run after the current case is reported.
    Deferred cleanups are flushed after every case.
    """
    global term_cnt
    global KSFT_RESULT

    # Work on a copy, the discovered cases appended below must not leak
    # into the list object owned by the caller.
    cases = list(cases) if cases else []

    if globs and case_pfx:
        # str.startswith() accepts a tuple of candidate prefixes
        prefixes = tuple(case_pfx)
        for key, value in globs.items():
            if callable(value) and key.startswith(prefixes):
                cases.append(value)

    term_cnt = 0
    prev_sigterm = signal.signal(signal.SIGTERM, _ksft_intr)

    totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}

    # Make sure the previous SIGTERM handler is reinstated even if
    # something unexpected escapes from the loop below.
    try:
        print("TAP version 13")
        print("1.." + str(len(cases)))

        stop = False
        for cnt, case in enumerate(cases, start=1):
            KSFT_RESULT = True
            comment = ""
            cnt_key = ""

            try:
                case(*args)
            except KsftSkipEx as e:
                comment = "SKIP " + str(e)
                cnt_key = 'skip'
            except KsftXfailEx as e:
                comment = "XFAIL " + str(e)
                cnt_key = 'xfail'
            except BaseException as e:
                stop |= isinstance(e, KeyboardInterrupt)
                tb = traceback.format_exc()
                for line in tb.strip().split('\n'):
                    ksft_pr("Exception|", line)
                if stop:
                    ksft_pr(f"Stopping tests due to {type(e).__name__}.")
                KSFT_RESULT = False
                cnt_key = 'fail'

            ksft_flush_defer()

            if not cnt_key:
                cnt_key = 'pass' if KSFT_RESULT else 'fail'

            ktap_result(KSFT_RESULT, cnt, case, comment=comment)
            totals[cnt_key] += 1

            if stop:
                break
    finally:
        signal.signal(signal.SIGTERM, prev_sigterm)

    print(
        f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0"
    )
276
277
def ksft_exit():
    """Exit the process with status 0 if every reported result was ok, 1 otherwise."""
    exit_code = 1
    if KSFT_RESULT_ALL:
        exit_code = 0
    sys.exit(exit_code)
281