xref: /linux/tools/testing/selftests/net/lib/py/ksft.py (revision eec8359f0797ef87c6ef6cbed6de08b02073b833)
1# SPDX-License-Identifier: GPL-2.0
2
3import builtins
4import functools
5import inspect
6import sys
7import time
8import traceback
9from .consts import KSFT_MAIN_NAME
10from .utils import global_defer_queue
11
# Result of the test case currently running: ksft_run() resets it to True
# before each case, and failed checks (_fail) or failing deferred cleanups
# (ksft_flush_defer) set it to False.
KSFT_RESULT = None
# Aggregate of every result passed to ktap_result(); ksft_exit() turns it
# into the process exit status.
KSFT_RESULT_ALL = True
# Whether @ksft_disruptive tests may run; ksft_setup() overrides it from the
# DISRUPTIVE environment entry.
KSFT_DISRUPTIVE = True
15
16
class KsftFailEx(Exception):
    """Exception a test case may raise to abort itself as a failure.

    ksft_run() has no dedicated handler for it, so it lands in the generic
    exception path: the traceback is printed and the case counts as failed.
    """
19
20
class KsftSkipEx(Exception):
    """Exception a test case raises to be reported as skipped.

    ksft_run() reports the case with a "SKIP <message>" directive.
    """
23
24
class KsftXfailEx(Exception):
    """Exception a test case raises to be reported as an expected failure.

    ksft_run() reports the case with an "XFAIL <message>" directive.
    """
27
28
def ksft_pr(*objs, **kwargs):
    """Print *objs* as a KTAP diagnostic line, i.e. prefixed with "#".

    All keyword arguments (sep, end, file, flush) go straight to print().
    """
    fields = ("#",) + objs
    print(*fields, **kwargs)
31
32
def _fail(*args):
    """Record a failed check for the current test case and report where.

    Clears KSFT_RESULT, prints every frame between ksft_run() and the check
    helper that called us (i.e. the test case call chain) with its source
    line, then prints *args* as the failure message. It does not raise, so
    the test case keeps running.
    """
    global KSFT_RESULT
    KSFT_RESULT = False

    stack = inspect.stack()
    try:
        started = False
        # stack[0] is this function, stack[1] the check helper that called
        # it; walk the remaining frames outermost-first.
        for frame in reversed(stack[2:]):
            # Start printing from the test case function
            if not started:
                if frame.function == 'ksft_run':
                    started = True
                continue

            ksft_pr("Check| At " + frame.filename + ", line " + str(frame.lineno) +
                    ", in " + frame.function + ":")
            # code_context is None when the source is unavailable (code
            # compiled from a string, missing .py files, ...); indexing it
            # would raise and hide the actual check failure.
            if frame.code_context:
                ksft_pr("Check|     " + frame.code_context[0].strip())
    finally:
        # The stack includes this frame, so keeping it in a local forms a
        # reference cycle; drop it explicitly as the inspect docs advise.
        del stack
    ksft_pr(*args)
50
51
def ksft_eq(a, b, comment=""):
    """Check that *a* == *b*; on mismatch record a failure via _fail().

    The failure is recorded rather than raised, so the test case continues.
    """
    # No `global` needed: KSFT_RESULT is only written inside _fail().
    if a != b:
        _fail("Check failed", a, "!=", b, comment)
56
57
def ksft_ne(a, b, comment=""):
    """Check that *a* != *b*; if they are equal record a failure via _fail().

    The failure is recorded rather than raised, so the test case continues.
    """
    # No `global` needed: KSFT_RESULT is only written inside _fail().
    if a == b:
        _fail("Check failed", a, "==", b, comment)
62
63
def ksft_true(a, comment=""):
    """Check that *a* is truthy; otherwise record a failure via _fail()."""
    if a:
        return
    _fail("Check failed", a, "does not eval to True", comment)
67
68
def ksft_in(a, b, comment=""):
    """Check that *a* is a member of *b*; otherwise record a failure."""
    if a in b:
        return
    _fail("Check failed", a, "not in", b, comment)
72
73
def ksft_is(a, b, comment=""):
    """Check that *a* and *b* are the same object; otherwise record a failure."""
    if a is b:
        return
    _fail("Check failed", a, "is not", b, comment)
77
78
def ksft_ge(a, b, comment=""):
    """Check that *a* is not less than *b*; otherwise record a failure.

    The test is literally ``a < b`` (not ``not a >= b``), so unordered
    values such as NaN never trigger a failure.
    """
    too_small = a < b
    if too_small:
        _fail("Check failed", a, "<", b, comment)
82
83
def ksft_lt(a, b, comment=""):
    """Check that *a* is not greater than or equal to *b*; otherwise fail.

    The test is literally ``a >= b``, so unordered values such as NaN never
    trigger a failure.
    """
    too_big = a >= b
    if too_big:
        _fail("Check failed", a, ">=", b, comment)
87
88
class ksft_raises:
    """Context manager checking that its body raises exactly *expected_type*.

    Usage::

        with ksft_raises(OSError) as cm:
            ...
        cm.exception  # the raised exception, or None

    A missing exception or one of a different type (compared with ==, so
    subclasses do not match) is recorded through _fail(). Only an exception
    of exactly the expected type is suppressed; others keep propagating.
    """

    def __init__(self, expected_type):
        self.expected_type = expected_type
        # Filled in by __exit__ with whatever the body raised (or None).
        self.exception = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        wanted = self.expected_type
        if exc_type is None:
            _fail(f"Expected exception {str(wanted.__name__)}, none raised")
        elif wanted != exc_type:
            _fail(f"Expected exception {str(wanted.__name__)}, raised {str(exc_type.__name__)}")
        self.exception = exc_val
        # A truthy return tells Python to swallow the in-flight exception.
        return wanted == exc_type
105
106
def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
    """Poll *cond* until it returns truthy or *deadline* seconds elapse.

    *cond* is evaluated before every deadline check, so a condition that is
    already true returns immediately. On timeout the failure is recorded via
    _fail() and the function returns; it does not raise.
    """
    give_up_at = time.monotonic() + deadline
    while not cond():
        if time.monotonic() > give_up_at:
            _fail("Waiting for condition timed out", comment)
            return
        time.sleep(sleep)
116
117
def _ktap_case_name(case):
    """Return a printable name for the test callable *case*.

    Plain functions use __name__. Callables without one (functools.partial
    objects, callable instances) fall back to the wrapped .func's name, then
    to the type name, instead of raising AttributeError.
    """
    name = getattr(case, "__name__", None)
    if name is None:
        name = getattr(getattr(case, "func", None), "__name__", None)
    if name is None:
        name = type(case).__name__
    return str(name)


def ktap_result(ok, cnt=1, case="", comment=""):
    """Print one KTAP result line and fold *ok* into KSFT_RESULT_ALL.

    Args:
        ok: case outcome; a falsy value prints "not ok" and makes the
            aggregate result (used by ksft_exit) false.
        cnt: the case's KTAP test number.
        case: the test callable; its name is appended to KSFT_MAIN_NAME.
        comment: optional directive/description placed after " # ".
    """
    global KSFT_RESULT_ALL
    KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok

    res = ""
    if not ok:
        res += "not "
    res += "ok " + str(cnt) + " " + KSFT_MAIN_NAME
    if case:
        res += "." + _ktap_case_name(case)
    if comment:
        # A KTAP result must stay on one line; a multi-line message (for
        # example str() of an exception) would otherwise emit unprefixed
        # lines that corrupt the output.
        res += " # " + " ".join(str(comment).splitlines())
    print(res)
133
134
def ksft_flush_defer():
    """Run and drain every entry queued in global_defer_queue.

    Entries are taken with pop(); presumably last-in-first-out, assuming
    global_defer_queue is a list (it comes from .utils, so verify there).
    A callback that raises does not stop the flush. Its traceback is
    printed with a "Defer Exception|" prefix, KSFT_RESULT is cleared, and
    the remaining callbacks still run.
    """
    global KSFT_RESULT

    pending = len(global_defer_queue)
    idx = 0
    while global_defer_queue:
        idx += 1
        entry = global_defer_queue.pop()
        try:
            entry.exec_only()
        except BaseException:  # deliberately broad (matches a bare except), incl. KeyboardInterrupt
            ksft_pr(f"Exception while handling defer / cleanup (callback {idx} of {pending})!")
            for line in traceback.format_exc().strip().split('\n'):
                ksft_pr("Defer Exception|", line)
            KSFT_RESULT = False
151
152
def ksft_disruptive(func):
    """
    Decorator for tests that disturb the system (e.g. bring an interface
    down). When KSFT_DISRUPTIVE is false (DISRUPTIVE=False in the
    environment passed to ksft_setup), calling the wrapped test raises
    KsftSkipEx instead of running it.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if KSFT_DISRUPTIVE:
            return func(*args, **kwargs)
        raise KsftSkipEx("marked as disruptive")

    return wrapper
166
167
def ksft_setup(env):
    """
    Setup test framework global state from the environment.

    Currently only DISRUPTIVE is consumed. When present it is parsed as a
    boolean ("yes"/"true"/"no"/"false", case-insensitive, or any integer
    string) and stored in KSFT_DISRUPTIVE. Returns *env* unchanged.

    Raises Exception("failed to parse DISRUPTIVE") when the value cannot be
    parsed; the underlying parse error is chained as __cause__.
    """
    global KSFT_DISRUPTIVE

    def get_bool(env, name):
        value = env.get(name, "").lower()
        if value in ["yes", "true"]:
            return True
        if value in ["no", "false"]:
            return False
        try:
            return bool(int(value))
        except (TypeError, ValueError) as exc:
            # Only int()'s parse errors are translated; the exception type
            # is kept for existing callers, the cause is chained for context.
            raise Exception(f"failed to parse {name}") from exc

    if "DISRUPTIVE" in env:
        KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE")

    return env
189
190
def _collect_cases(cases, globs, case_pfx):
    """Build the ordered list of test callables for ksft_run().

    Starts from a copy of *cases*, so the caller's list is never modified.
    It then appends, in *globs* iteration order, every callable whose key
    starts with one of the *case_pfx* prefixes. A single prefix may be
    given as a plain string.
    """
    selected = list(cases) if cases else []
    if globs and case_pfx:
        # A bare string would otherwise be iterated character by character.
        if isinstance(case_pfx, str):
            case_pfx = (case_pfx,)
        prefixes = tuple(case_pfx)
        selected.extend(value for key, value in globs.items()
                        if callable(value) and key.startswith(prefixes))
    return selected


def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
    """Run test cases and report them in KTAP format.

    Args:
        cases: explicit test callables, run first (the list is not modified).
        globs: optional name -> object mapping scanned for more test callables.
        case_pfx: prefix string, or iterable of prefixes, selecting callables
            from *globs* by name.
        args: positional arguments passed to every test case.

    Each case runs with KSFT_RESULT reset to True. KsftSkipEx and
    KsftXfailEx report it as SKIP and XFAIL. Any other exception is printed
    and marks the case failed. A KeyboardInterrupt also stops the run after
    that case. Deferred cleanups are flushed after every case. A totals line
    is printed at the end.
    """
    cases = _collect_cases(cases, globs, case_pfx)

    totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}

    print("KTAP version 1")
    print("1.." + str(len(cases)))

    global KSFT_RESULT
    cnt = 0
    stop = False
    for case in cases:
        KSFT_RESULT = True
        cnt += 1
        comment = ""
        cnt_key = ""

        try:
            case(*args)
        except KsftSkipEx as e:
            comment = "SKIP " + str(e)
            cnt_key = 'skip'
        except KsftXfailEx as e:
            comment = "XFAIL " + str(e)
            cnt_key = 'xfail'
        except BaseException as e:
            # Broad on purpose: one broken case must not abort the whole
            # run, but Ctrl-C should still stop after reporting it.
            stop |= isinstance(e, KeyboardInterrupt)
            tb = traceback.format_exc()
            for line in tb.strip().split('\n'):
                ksft_pr("Exception|", line)
            if stop:
                ksft_pr("Stopping tests due to KeyboardInterrupt.")
            KSFT_RESULT = False
            cnt_key = 'fail'

        # Cleanups may themselves fail and clear KSFT_RESULT.
        ksft_flush_defer()

        if not cnt_key:
            cnt_key = 'pass' if KSFT_RESULT else 'fail'

        ktap_result(KSFT_RESULT, cnt, case, comment=comment)
        totals[cnt_key] += 1

        if stop:
            break

    print(
        f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0"
    )
248    )
249
250
def ksft_exit():
    """Exit the process with the aggregate result of all reported cases.

    The status is 0 when KSFT_RESULT_ALL (the AND of every ktap_result() ok
    value) is truthy, and 1 otherwise.
    """
    status = 0 if KSFT_RESULT_ALL else 1
    sys.exit(status)
253    sys.exit(0 if KSFT_RESULT_ALL else 1)
254