xref: /linux/tools/testing/selftests/net/lib/py/ksft.py (revision eb3ab13d997a2f12ec9d557b6ae2aea4e28e2bc3)
1# SPDX-License-Identifier: GPL-2.0
2
3import builtins
4import functools
5import inspect
6import sys
7import time
8import traceback
9from .consts import KSFT_MAIN_NAME
10from .utils import global_defer_queue
11
# Outcome of the test case currently running under ksft_run(): reset to
# True before each case, cleared by _fail() and by failed deferred cleanups.
KSFT_RESULT = None
# Running AND of every result passed to ktap_result(); ksft_exit() derives
# the process exit status from it.
KSFT_RESULT_ALL = True
# Whether @ksft_disruptive tests may run; overridden from the DISRUPTIVE
# variable by ksft_setup().
KSFT_DISRUPTIVE = True
15
16
class KsftFailEx(Exception):
    """
    Exception a test case may raise to report failure.

    ksft_run() has no dedicated handler for it: it falls into the generic
    exception handler, which prints the traceback and counts the case as
    failed.
    """
19
20
class KsftSkipEx(Exception):
    """
    Raised to skip the current test case.

    ksft_run() reports the case with a "SKIP <message>" comment and counts it
    under "skip". ksft_disruptive() raises it when disruptive tests are
    disabled.
    """
23
24
class KsftXfailEx(Exception):
    """
    Raised by a test case to report an expected failure.

    ksft_run() reports the case with an "XFAIL <message>" comment and counts
    it under "xfail".
    """
27
28
def ksft_pr(*objs, **kwargs):
    """
    Print *objs* as a KTAP diagnostic line, i.e. prefixed with "#".

    Keyword arguments are forwarded unchanged to print().
    """
    line_items = ("#", *objs)
    print(*line_items, **kwargs)
31
32
33def _fail(*args):
34    global KSFT_RESULT
35    KSFT_RESULT = False
36
37    stack = inspect.stack()
38    started = False
39    for frame in reversed(stack[2:]):
40        # Start printing from the test case function
41        if not started:
42            if frame.function == 'ksft_run':
43                started = True
44            continue
45
46        ksft_pr("Check| At " + frame.filename + ", line " + str(frame.lineno) +
47                ", in " + frame.function + ":")
48        ksft_pr("Check|     " + frame.code_context[0].strip())
49    ksft_pr(*args)
50
51
def ksft_eq(a, b, comment=""):
    """
    Check that a == b; on mismatch, record a failure for the current test.

    The failure is reported through _fail(), which marks the running test
    as failed but does not raise, so the test keeps going.
    """
    # No 'global KSFT_RESULT' needed here: _fail() owns that update.
    if a != b:
        _fail("Check failed", a, "!=", b, comment)
56
57
def ksft_true(a, comment=""):
    """
    Check that *a* is truthy; otherwise record a failure for the current test.

    Does not raise on failure; _fail() only marks the test as failed.
    """
    if a:
        return
    _fail("Check failed", a, "does not eval to True", comment)
61
62
def ksft_in(a, b, comment=""):
    """
    Check that *a* is contained in *b*; otherwise record a failure.

    Does not raise on failure; _fail() only marks the test as failed.
    """
    if a in b:
        return
    _fail("Check failed", a, "not in", b, comment)
66
67
def ksft_ge(a, b, comment=""):
    """
    Check that a >= b; otherwise record a failure for the current test.

    The failure condition is literally ``a < b``, so operands that are
    unordered with respect to each other (e.g. float NaN) pass this check.
    """
    if not a < b:
        return
    _fail("Check failed", a, "<", b, comment)
71
72
def ksft_lt(a, b, comment=""):
    """
    Check that a < b; otherwise record a failure for the current test.

    The failure condition is literally ``a >= b``, so operands that are
    unordered with respect to each other (e.g. float NaN) pass this check.
    """
    if not a >= b:
        return
    _fail("Check failed", a, ">=", b, comment)
76
77
class ksft_raises:
    """
    Context manager checking that its body raises *expected_type*.

    Usage::

        with ksft_raises(SomeError) as cm:
            ...
        # cm.exception holds whatever was raised (or None)

    Matching compares the raised type with == (no subclass matching). On a
    match the exception is stored in ``exception`` and suppressed. If nothing
    is raised, or a different type is raised, a check failure is recorded via
    _fail(). A mismatched exception is also stored and then propagates.
    """

    def __init__(self, expected_type):
        self.expected_type = expected_type
        self.exception = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        expected = self.expected_type
        if exc_type is None:
            _fail(f"Expected exception {expected.__name__}, none raised")
        elif expected != exc_type:
            _fail(f"Expected exception {expected.__name__}, raised {exc_type.__name__}")
        self.exception = exc_val
        # Returning True suppresses the exception: only do so for the expected one.
        return expected == exc_type
94
95
def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
    """
    Poll cond() every *sleep* seconds until it returns a truthy value.

    *deadline* is a timeout in seconds measured from the call (not an
    absolute time). cond() is re-evaluated before each timeout check. Once
    the timeout has passed with cond() still falsy, a check failure is
    recorded via _fail() and the function returns without raising.
    """
    give_up_at = time.monotonic() + deadline
    while not cond():
        if time.monotonic() > give_up_at:
            _fail("Waiting for condition timed out", comment)
            return
        time.sleep(sleep)
105
106
def ktap_result(ok, cnt=1, case="", comment=""):
    """
    Print one KTAP result line and fold *ok* into the overall result.

    Line format: "[not ]ok <cnt> <KSFT_MAIN_NAME>[.<case name>][ # <comment>]".

    Args:
        ok: truthy if the test passed; ANDed into KSFT_RESULT_ALL.
        cnt: test number printed on the line.
        case: the test callable, or falsy to omit the case name. The name
            comes from ``__name__``. Callables without one (e.g. a
            functools.partial passed to ksft_run) use the wrapped function's
            name, then fall back to the type name, instead of raising.
        comment: optional text appended after " # " (e.g. "SKIP ...").
    """
    global KSFT_RESULT_ALL
    KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok

    res = ("ok " if ok else "not ok ") + str(cnt) + " " + KSFT_MAIN_NAME
    if case:
        name = getattr(case, "__name__", None)
        if name is None and isinstance(case, functools.partial):
            # partial objects have no __name__; report the wrapped function.
            name = getattr(case.func, "__name__", None)
        if name is None:
            name = type(case).__name__
        res += "." + str(name)
    if comment:
        res += " # " + comment
    print(res)
122
123
def ksft_flush_defer():
    """
    Run every pending deferred cleanup in global_defer_queue.

    Entries are taken with pop() until the queue is empty, so callbacks
    queued by other callbacks also run. For a list, the most recently
    deferred entry runs first. Each entry's exec_only() is invoked. Any
    exception it raises, including KeyboardInterrupt, is caught and logged
    with its traceback, and the current test is marked as failed. The
    remaining cleanups still run.
    """
    global KSFT_RESULT

    total = len(global_defer_queue)
    idx = 0
    while global_defer_queue:
        idx += 1
        pending = global_defer_queue.pop()
        try:
            pending.exec_only()
        except BaseException:
            ksft_pr(f"Exception while handling defer / cleanup (callback {idx} of {total})!")
            for tb_line in traceback.format_exc().strip().split('\n'):
                ksft_pr("Defer Exception|", tb_line)
            KSFT_RESULT = False
140
141
def ksft_disruptive(func):
    """
    Decorator that marks the test as disruptive (e.g. the test
    that can down the interface).

    Disruptive tests can be skipped by setting DISRUPTIVE to a false value
    (e.g. DISRUPTIVE=False, no or 0) in the environment passed to
    ksft_setup(). The flag is read each time the decorated function is
    called, not at decoration time, so ksft_setup() may run after the
    decorator is applied. When disabled, the call raises KsftSkipEx, which
    ksft_run() reports as a SKIP.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not KSFT_DISRUPTIVE:
            raise KsftSkipEx("marked as disruptive")
        return func(*args, **kwargs)
    return wrapper
155
156
def ksft_setup(env):
    """
    Setup test framework global state from the environment.

    Only DISRUPTIVE is consumed: when present it sets KSFT_DISRUPTIVE.
    Accepted values are case-insensitive and ignore surrounding whitespace:
    "yes"/"true", "no"/"false", or any integer (non-zero means True).

    Args:
        env: mapping supporting ``in`` and ``.get()`` whose values are
            strings (presumably os.environ or a plain dict).
    Returns:
        *env*, unchanged.
    Raises:
        Exception: if DISRUPTIVE is present but cannot be parsed. The
            underlying ValueError is chained as ``__cause__``.
    """
    global KSFT_DISRUPTIVE

    def get_bool(env, name):
        value = env.get(name, "").strip().lower()
        if value in ("yes", "true"):
            return True
        if value in ("no", "false"):
            return False
        try:
            return bool(int(value))
        except ValueError as err:
            raise Exception(f"failed to parse {name}") from err

    if "DISRUPTIVE" in env:
        KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE")

    return env
178
179
def _ksft_collect_cases(cases, globs, case_pfx):
    """
    Build the list of test cases for ksft_run().

    Returns a NEW list holding *cases* (if any) followed by every callable in
    *globs* whose name starts with one of the prefixes in *case_pfx*, in
    *globs* iteration order. The caller's *cases* object is never modified.
    """
    collected = list(cases or [])
    if globs and case_pfx:
        # A bare string is one prefix, not a collection of one-char prefixes;
        # a tuple also lets a one-shot iterable be reused for every key.
        if isinstance(case_pfx, str):
            prefixes = (case_pfx,)
        else:
            prefixes = tuple(case_pfx)
        for key, value in globs.items():
            if callable(value) and key.startswith(prefixes):
                collected.append(value)
    return collected


def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
    """
    Run test cases and print their results in KTAP format.

    Args:
        cases: iterable of test callables to run first (optional).
        globs: mapping of names to objects (e.g. globals()) searched for
            additional cases; only used together with *case_pfx*.
        case_pfx: a prefix string or a collection of prefixes; callables in
            *globs* whose name starts with one of them are appended.
        args: positional arguments passed to every case.

    Each case runs with KSFT_RESULT reset to True; checks such as ksft_eq()
    clear it on failure. KsftSkipEx / KsftXfailEx report the case as
    SKIP / XFAIL, and any other exception fails it. A KeyboardInterrupt
    also stops the run after the current case. Deferred cleanups are
    flushed after every case, and a totals line is printed at the end.
    """
    global KSFT_RESULT

    cases = _ksft_collect_cases(cases, globs, case_pfx)

    totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}

    print("KTAP version 1")
    print("1.." + str(len(cases)))

    stop = False
    for cnt, case in enumerate(cases, start=1):
        KSFT_RESULT = True
        comment = ""
        cnt_key = ""

        # Keep the call in this function: _fail() looks for the 'ksft_run'
        # frame to decide where the reported call chain starts, so calling
        # the case from a helper would add that helper to every report.
        try:
            case(*args)
        except KsftSkipEx as e:
            comment = "SKIP " + str(e)
            cnt_key = 'skip'
        except KsftXfailEx as e:
            comment = "XFAIL " + str(e)
            cnt_key = 'xfail'
        except BaseException as e:
            stop |= isinstance(e, KeyboardInterrupt)
            tb = traceback.format_exc()
            for line in tb.strip().split('\n'):
                ksft_pr("Exception|", line)
            if stop:
                ksft_pr("Stopping tests due to KeyboardInterrupt.")
            KSFT_RESULT = False
            cnt_key = 'fail'

        ksft_flush_defer()

        # Cases that returned normally are classified by KSFT_RESULT, which
        # failed checks or failed cleanups may have cleared.
        if not cnt_key:
            cnt_key = 'pass' if KSFT_RESULT else 'fail'

        ktap_result(KSFT_RESULT, cnt, case, comment=comment)
        totals[cnt_key] += 1

        if stop:
            break

    print(
        f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0"
    )
238
239
def ksft_exit():
    """
    Exit the process: status 0 if every reported result passed, 1 otherwise.

    KSFT_RESULT_ALL is accumulated by ktap_result(). It is only read here,
    so no ``global`` declaration is needed.
    """
    sys.exit(0 if KSFT_RESULT_ALL else 1)
243