xref: /linux/tools/testing/selftests/net/lib/py/ksft.py (revision 80970e0fc07ecb1fbaefeb2c912aa2b0c04ed557)
1# SPDX-License-Identifier: GPL-2.0
2
3import functools
4import inspect
5import signal
6import sys
7import time
8import traceback
9from .consts import KSFT_MAIN_NAME
10from .utils import global_defer_queue
11
# Result of the test case ksft_run() is currently executing: reset to True
# before each case and set to False by failed checks, test exceptions or
# cleanup errors. None until ksft_run() starts.
KSFT_RESULT: "bool | None" = None
# Conjunction of every result printed via ktap_result(); ksft_exit() derives
# the process exit status from it.
KSFT_RESULT_ALL: bool = True
# Whether @ksft_disruptive tests may run; ksft_setup() overrides it from the
# DISRUPTIVE environment variable.
KSFT_DISRUPTIVE: bool = True
15
16
class KsftFailEx(Exception):
    """Exception a test case may raise to fail itself.

    ksft_run() has no dedicated handler for it, so it is reported like any
    other unexpected exception: the traceback is printed and the case fails.
    """
19
20
class KsftSkipEx(Exception):
    """Raised to skip the current test case.

    ksft_run() appends a ``# SKIP <message>`` comment to the case's result
    line and counts it under "skip" in the totals.
    """
23
24
class KsftXfailEx(Exception):
    """Raised to mark the current test case as an expected failure.

    ksft_run() appends a ``# XFAIL <message>`` comment to the case's result
    line and counts it under "xfail" in the totals.
    """
27
28
class KsftTerminate(KeyboardInterrupt):
    """Raised by the SIGTERM handler that ksft_run() installs.

    Deriving from KeyboardInterrupt (not Exception) keeps ``except
    Exception`` blocks in test code from swallowing it, and makes ksft_run()
    stop running further cases once it has been seen.
    """
31
32
def ksft_pr(*objs, **kwargs):
    """Print *objs* as a KTAP diagnostic line, i.e. prefixed with ``#``.

    Extra keyword arguments are forwarded to print(); the output is always
    flushed, overriding any ``flush`` the caller passed.
    """
    print("#", *objs, **{**kwargs, "flush": True})
36
37
def _fail(*args):
    """Record a check failure for the running test case and report where.

    Clears KSFT_RESULT, prints the call chain starting at the first frame
    below ``ksft_run`` (the test case function) down to the caller of the
    ksft_* check helper, and finally prints *args* as the failure message.
    """
    global KSFT_RESULT
    KSFT_RESULT = False

    stack = inspect.stack()
    try:
        started = False
        # stack[0] is this function and stack[1] the ksft_* helper that
        # called it; walk the rest outermost-first and only print frames
        # nested inside ksft_run().
        for frame in reversed(stack[2:]):
            if not started:
                if frame.function == 'ksft_run':
                    started = True
                continue

            ksft_pr(f"Check| At {frame.filename}, line {frame.lineno}, in {frame.function}:")
            # code_context is None when the source cannot be read (e.g. code
            # from exec() or a .pyc-only install); a missing source line must
            # not turn a reported check failure into a TypeError.
            if frame.code_context:
                ksft_pr("Check|     " + frame.code_context[0].strip())
    finally:
        # The FrameInfo list references live frames, including this one;
        # drop it so the test case's locals are not kept alive by the cycle.
        del stack
    ksft_pr(*args)
55
56
def ksft_eq(a, b, comment=""):
    """Check that ``a == b``; otherwise record a failure for the current test."""
    # No ``global KSFT_RESULT`` needed: _fail() is what updates it.
    if a != b:
        _fail("Check failed", a, "!=", b, comment)
61
62
def ksft_ne(a, b, comment=""):
    """Check that ``a != b``; otherwise record a failure for the current test."""
    # No ``global KSFT_RESULT`` needed: _fail() is what updates it.
    if a == b:
        _fail("Check failed", a, "==", b, comment)
67
68
def ksft_true(a, comment=""):
    """Check that *a* is truthy; otherwise record a failure for the current test."""
    if a:
        return
    _fail("Check failed", a, "does not eval to True", comment)
72
73
def ksft_not_none(a, comment=""):
    """Check that *a* is not None; otherwise record a failure for the current test."""
    if a is not None:
        return
    _fail("Check failed", a, "is None", comment)
77
78
def ksft_in(a, b, comment=""):
    """Check that *a* is contained in *b*; otherwise record a failure."""
    if a in b:
        return
    _fail("Check failed", a, "not in", b, comment)
82
83
def ksft_not_in(a, b, comment=""):
    """Check that *a* is not contained in *b*; otherwise record a failure."""
    if a not in b:
        return
    _fail("Check failed", a, "in", b, comment)
87
88
def ksft_is(a, b, comment=""):
    """Check that *a* and *b* are the same object; otherwise record a failure."""
    if a is b:
        return
    _fail("Check failed", a, "is not", b, comment)
92
93
def ksft_ge(a, b, comment=""):
    """Check that *a* is not less than *b*; otherwise record a failure.

    The failure condition is literally ``a < b`` (so e.g. NaN never fails).
    """
    too_small = a < b
    if too_small:
        _fail("Check failed", a, "<", b, comment)
97
98
def ksft_gt(a, b, comment=""):
    """Check that *a* is greater than *b*; otherwise record a failure.

    The failure condition is literally ``a <= b`` (so e.g. NaN never fails).
    """
    not_greater = a <= b
    if not_greater:
        _fail("Check failed", a, "<=", b, comment)
102
103
def ksft_lt(a, b, comment=""):
    """Check that *a* is less than *b*; otherwise record a failure.

    The failure condition is literally ``a >= b`` (so e.g. NaN never fails).
    """
    not_less = a >= b
    if not_less:
        _fail("Check failed", a, ">=", b, comment)
107
108
class ksft_raises:
    """Context manager asserting that its body raises *expected_type*.

    Usage::

        with ksft_raises(KeyError) as cm:
            do_lookup()
        # cm.exception is the exception object seen on exit (or None)

    The raised type must equal *expected_type* exactly; subclasses do not
    match. A matching exception is suppressed. Otherwise a check failure is
    recorded, either because nothing was raised or because a different
    exception was raised; in the latter case that exception keeps
    propagating.
    """

    def __init__(self, expected_type):
        self.expected_type = expected_type
        self.exception = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        want = self.expected_type
        if exc_type is None:
            _fail(f"Expected exception {want.__name__!s}, none raised")
        elif want != exc_type:
            _fail(f"Expected exception {want.__name__!s}, raised {exc_type.__name__!s}")
        self.exception = exc_val
        # Returning True suppresses the exception; only do so for a match.
        return want == exc_type
125
126
def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
    """Poll *cond* until it returns a truthy value or *deadline* seconds pass.

    *cond* is called immediately, then again after each *sleep*-second pause.
    On timeout a check failure (carrying *comment*) is recorded and the
    function returns normally instead of raising.
    """
    give_up_at = time.monotonic() + deadline
    while not cond():
        if time.monotonic() > give_up_at:
            _fail("Waiting for condition timed out", comment)
            return
        time.sleep(sleep)
136
137
def ktap_result(ok, cnt=1, case_name="", comment=""):
    """Print one KTAP result line and fold *ok* into KSFT_RESULT_ALL.

    Line format: ``[not ]ok <cnt> <KSFT_MAIN_NAME>[.<case_name>][ # <comment>]``.
    """
    global KSFT_RESULT_ALL
    KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok

    status = "ok" if ok else "not ok"
    test_id = KSFT_MAIN_NAME + ("." + case_name if case_name else "")
    line = " ".join([status, str(cnt), test_id])
    if comment:
        line += " # " + comment
    print(line, flush=True)
153
154
def ksft_flush_defer():
    """Run and drain every queued deferred-cleanup callback.

    Entries are taken with ``global_defer_queue.pop()`` and run through
    their ``exec_only()`` method until the queue is empty. Assuming the
    queue is a list, that is most-recently-deferred first (TODO confirm
    against utils). A failing callback does not stop the rest: its traceback
    is printed and KSFT_RESULT is cleared.
    """
    global KSFT_RESULT

    # Queue length before draining; callbacks that defer more work can make
    # the running index exceed it.
    total = len(global_defer_queue)
    idx = 0
    while global_defer_queue:
        idx += 1
        entry = global_defer_queue.pop()
        try:
            entry.exec_only()
        except BaseException:
            # Catching BaseException means even KeyboardInterrupt /
            # KsftTerminate does not cut the remaining cleanups short.
            ksft_pr(f"Exception while handling defer / cleanup (callback {idx} of {total})!")
            for tb_line in traceback.format_exc().strip().split('\n'):
                ksft_pr("Defer Exception|", tb_line)
            KSFT_RESULT = False
171
172
def ksft_disruptive(func):
    """
    Decorator marking a test case as disruptive (e.g. one that may bring the
    interface down). When KSFT_DISRUPTIVE is false (e.g. DISRUPTIVE=False in
    the environment, applied by ksft_setup()), calling the decorated test
    raises KsftSkipEx instead of running it.
    """

    @functools.wraps(func)
    def guarded(*args, **kwargs):
        # Looked up at call time, so ksft_setup() may run after decoration.
        if KSFT_DISRUPTIVE:
            return func(*args, **kwargs)
        raise KsftSkipEx("marked as disruptive")

    return guarded
186
187
def ksft_setup(env):
    """
    Setup test framework global state from the environment.

    *env* is a mapping of environment-style string variables. Only
    ``DISRUPTIVE`` is consumed. When present it must be a boolean
    ("yes"/"true"/"no"/"false" in any case, or an integer; surrounding
    whitespace is ignored), and it sets KSFT_DISRUPTIVE. Returns *env*
    unchanged.

    Raises:
        Exception: if DISRUPTIVE is present but cannot be parsed (the
            underlying ValueError is chained as ``__cause__``).
    """
    global KSFT_DISRUPTIVE

    def get_bool(env, name):
        value = env.get(name, "").strip().lower()
        if value in ("yes", "true"):
            return True
        if value in ("no", "false"):
            return False
        try:
            return bool(int(value))
        except ValueError as e:
            # int() on a str can only fail with ValueError; don't mask
            # unrelated errors, and keep the original cause attached.
            raise Exception(f"failed to parse {name}") from e

    if "DISRUPTIVE" in env:
        KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE")

    return env
209
210
def _ksft_intr(signum, frame):
    """SIGTERM handler installed by ksft_run().

    Only the first signal interrupts the run (by raising KsftTerminate);
    later ones are logged and ignored.
    """
    # ksft runner.sh sends 2 SIGTERMs in a row on a timeout; if the second
    # one also raised, it would abort the cleanup triggered by the first.
    global term_cnt
    term_cnt += 1
    if term_cnt == 1:
        raise KsftTerminate()
    ksft_pr(f"Ignoring SIGTERM (cnt: {term_cnt}), already exiting...")
220
221
222def _ksft_generate_test_cases(cases, globs, case_pfx, args):
223    """Generate a flat list of (func, args, name) tuples"""
224
225    cases = cases or []
226    test_cases = []
227
228    # If using the globs method find all relevant functions
229    if globs and case_pfx:
230        for key, value in globs.items():
231            if not callable(value):
232                continue
233            for prefix in case_pfx:
234                if key.startswith(prefix):
235                    cases.append(value)
236                    break
237
238    for func in cases:
239        test_cases.append((func, args, func.__name__))
240
241    return test_cases
242
243
def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
    """Run test cases and report them in KTAP format.

    Cases are *cases* plus, when both *globs* and *case_pfx* are given, every
    callable in *globs* whose name starts with one of the *case_pfx*
    prefixes. Each case is called as ``func(*args)``, and its deferred
    cleanups are flushed right after. Output is the TAP header, one result
    line per case and a final totals line.

    KsftSkipEx / KsftXfailEx mark a case SKIP / XFAIL; any other exception
    fails it. A KeyboardInterrupt (including KsftTerminate, raised on the
    first SIGTERM) stops the run once the current case has been cleaned up.
    The previous SIGTERM handler is restored on every exit path.
    """
    global KSFT_RESULT
    global term_cnt

    test_cases = _ksft_generate_test_cases(cases, globs, case_pfx, args)

    term_cnt = 0
    prev_sigterm = signal.signal(signal.SIGTERM, _ksft_intr)
    try:
        totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}

        print("TAP version 13", flush=True)
        print("1.." + str(len(test_cases)), flush=True)

        cnt = 0
        stop = False
        for func, case_args, name in test_cases:
            KSFT_RESULT = True
            cnt += 1
            comment = ""
            cnt_key = ""

            try:
                # Keep this call directly in ksft_run's frame: _fail() finds
                # the test case in the stack as the frame below 'ksft_run'.
                func(*case_args)
            except KsftSkipEx as e:
                comment = "SKIP " + str(e)
                cnt_key = 'skip'
            except KsftXfailEx as e:
                comment = "XFAIL " + str(e)
                cnt_key = 'xfail'
            except BaseException as e:
                stop |= isinstance(e, KeyboardInterrupt)
                tb = traceback.format_exc()
                for line in tb.strip().split('\n'):
                    ksft_pr("Exception|", line)
                if stop:
                    ksft_pr(f"Stopping tests due to {type(e).__name__}.")
                KSFT_RESULT = False
                cnt_key = 'fail'

            ksft_flush_defer()

            # A KsftTerminate swallowed elsewhere (by ksft_flush_defer()'s
            # cleanup handler, or by the test itself) never reaches the
            # handler above, yet _ksft_intr() now ignores every further
            # SIGTERM. Stop here, otherwise the run can no longer be
            # terminated.
            if term_cnt and not stop:
                ksft_pr("Stopping tests due to SIGTERM.")
                stop = True

            if not cnt_key:
                cnt_key = 'pass' if KSFT_RESULT else 'fail'

            ktap_result(KSFT_RESULT, cnt, name, comment=comment)
            totals[cnt_key] += 1

            if stop:
                break
    finally:
        # Restore the caller's handler even when leaving via an exception.
        signal.signal(signal.SIGTERM, prev_sigterm)

    print(
        f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0",
        flush=True,
    )
299
300
def ksft_exit():
    """Exit the process: status 0 if every reported result passed, else 1."""
    # Read-only access to the module flag; no ``global`` declaration needed.
    sys.exit(1 if not KSFT_RESULT_ALL else 0)
304