xref: /linux/tools/testing/selftests/net/lib/py/ksft.py (revision dfecb0c5af3b07ebfa84be63a7a21bfc9e29a872)
1# SPDX-License-Identifier: GPL-2.0
2
3import functools
4import inspect
5import os
6import signal
7import sys
8import time
9import traceback
10from collections import namedtuple
11from .consts import KSFT_MAIN_NAME
12from . import utils
13
# Result of the test case currently running: None before ksft_run() starts,
# set to True at the start of each case, False once a check fails or the
# case raises.
KSFT_RESULT = None
# Logical AND of every result handed to ktap_result(); ksft_exit() maps it
# to the process exit status.
KSFT_RESULT_ALL = True
# Whether @ksft_disruptive tests may run; ksft_setup() can override it from
# the DISRUPTIVE environment variable.
KSFT_DISRUPTIVE = True


class KsftFailEx(Exception):
    """Raise from a test case to make it fail."""


class KsftSkipEx(Exception):
    """Raise from a test case to report it as skipped (SKIP)."""


class KsftXfailEx(Exception):
    """Raise from a test case to report it as an expected failure (XFAIL)."""


class KsftTerminate(KeyboardInterrupt):
    """
    Raised by the SIGTERM handler. Being a KeyboardInterrupt, it makes
    ksft_run() stop after the current case has been cleaned up.
    """
33
34
@functools.lru_cache()
def _ksft_supports_color():
    """
    Decide (once, the result is cached) whether result lines may be colored.

    Color is used only when NO_COLOR is absent from the environment,
    stdout is a terminal, and TERM is not "dumb".
    """
    no_color_requested = os.environ.get("NO_COLOR") is not None
    return bool(
        not no_color_requested
        and hasattr(sys.stdout, "isatty")
        and sys.stdout.isatty()
        and os.environ.get("TERM") != "dumb"
    )
44
45
def ksft_pr(*objs, **kwargs):
    """
    Print logs to stdout.

    Behaves like print() but log lines will be prefixed
    with # to prevent breaking the TAP output formatting.

    Extra arguments (on top of what print() supports):
      line_pfx - add extra string before each line

    As with print(), sep=None means a single space. Output is always
    flushed; any flush= argument from the caller is overridden.
    """
    sep = kwargs.pop("sep", None)
    if sep is None:
        # print() accepts sep=None as "use the default"; mirror that
        sep = " "
    pfx = kwargs.pop("line_pfx", "")
    pfx = "#" + (" " + pfx if pfx else "")
    kwargs["flush"] = True

    text = sep.join(str(obj) for obj in objs)
    # Re-apply the prefix after every embedded newline so multi-line
    # output (e.g. tracebacks) stays a block of "#" comment lines
    prefixed = f"\n{pfx} ".join(text.split('\n'))
    print(pfx, prefixed, **kwargs)
64
65
def _fail(*args):
    """
    Record a failed check and print where it happened.

    Sets KSFT_RESULT to False (the test keeps running), prints every frame
    between ksft_run() and the helper that called _fail() together with
    its source line when available, then prints @args via ksft_pr().
    """
    global KSFT_RESULT
    KSFT_RESULT = False

    stack = inspect.stack()
    started = False
    # stack[0] is _fail() itself and stack[1] the check helper calling it;
    # walk the remaining frames outermost first.
    for frame in reversed(stack[2:]):
        # Start printing from the test case function
        if not started:
            if frame.function == 'ksft_run':
                started = True
            continue

        ksft_pr(f"Check| At {frame.filename}, line {frame.lineno}, in {frame.function}:")
        # code_context is None when the source cannot be retrieved
        # (e.g. exec()'d code); don't let that mask the real failure
        if frame.code_context:
            ksft_pr("Check|     " + frame.code_context[0].strip())
    ksft_pr(*args)
83
84
def ksft_eq(a, b, comment=""):
    """
    Check that @a == @b; otherwise record a failure via _fail().

    A failed check does not raise; the test continues and is reported as
    failed at the end.
    """
    if a != b:
        _fail("Check failed", a, "!=", b, comment)
89
90
def ksft_ne(a, b, comment=""):
    """
    Check that @a != @b; otherwise record a failure via _fail().

    A failed check does not raise; the test continues and is reported as
    failed at the end.
    """
    if a == b:
        _fail("Check failed", a, "==", b, comment)
95
96
def ksft_true(a, comment=""):
    """Record a check failure unless @a is truthy."""
    if a:
        return
    _fail("Check failed", a, "does not eval to True", comment)
100
101
def ksft_not_none(a, comment=""):
    """Record a check failure if @a is None (other falsy values pass)."""
    if a is not None:
        return
    _fail("Check failed", a, "is None", comment)
105
106
def ksft_in(a, b, comment=""):
    """Record a check failure unless @a is contained in @b."""
    if a in b:
        return
    _fail("Check failed", a, "not in", b, comment)
110
111
def ksft_not_in(a, b, comment=""):
    """Record a check failure if @a is contained in @b."""
    if a not in b:
        return
    _fail("Check failed", a, "in", b, comment)
115
116
def ksft_is(a, b, comment=""):
    """Record a check failure unless @a and @b are the very same object."""
    if a is b:
        return
    _fail("Check failed", a, "is not", b, comment)
120
121
def ksft_ge(a, b, comment=""):
    """
    Record a check failure if @a < @b.

    Note this tests ``a < b`` rather than ``a >= b``; the two differ for
    NaN and partially ordered types such as sets.
    """
    if not a < b:
        return
    _fail("Check failed", a, "<", b, comment)
125
126
def ksft_gt(a, b, comment=""):
    """
    Record a check failure if @a <= @b.

    Note this tests ``a <= b`` rather than ``a > b``; the two differ for
    NaN and partially ordered types such as sets.
    """
    if not a <= b:
        return
    _fail("Check failed", a, "<=", b, comment)
130
131
def ksft_lt(a, b, comment=""):
    """
    Record a check failure if @a >= @b.

    Note this tests ``a >= b`` rather than ``a < b``; the two differ for
    NaN and partially ordered types such as sets.
    """
    if not a >= b:
        return
    _fail("Check failed", a, ">=", b, comment)
135
136
class ksft_raises:
    """
    Context manager checking that its body raises @expected_type.

        with ksft_raises(SomeError) as cm:
            do_something()
        # cm.exception is the raised exception (None if nothing was raised)

    Like an ``except`` clause, the body may raise @expected_type or any
    subclass of it, and @expected_type may be a tuple of exception types.
    A matching exception is suppressed. If nothing is raised, or an
    unrelated exception is raised, a check failure is recorded via
    _fail(); an unrelated exception also keeps propagating.
    """

    def __init__(self, expected_type):
        self.exception = None
        self.expected_type = expected_type

    def __enter__(self):
        return self

    def _expected_name(self):
        """Human readable name(s) of the expected exception type(s)."""
        if isinstance(self.expected_type, tuple):
            return " or ".join(t.__name__ for t in self.expected_type)
        return self.expected_type.__name__

    def __exit__(self, exc_type, exc_val, exc_tb):
        matched = exc_type is not None and issubclass(exc_type, self.expected_type)
        if exc_type is None:
            _fail(f"Expected exception {self._expected_name()}, none raised")
        elif not matched:
            _fail(f"Expected exception {self._expected_name()}, raised {exc_type.__name__}")
        self.exception = exc_val
        # Suppress the exception only if it is the expected one
        return matched
153
154
def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
    """
    Poll @cond() every @sleep seconds until it returns something truthy.

    If @deadline seconds (measured on the monotonic clock) pass without
    that happening, a check failure is recorded with @comment and the
    function returns; it never raises on timeout.
    """
    expires_at = time.monotonic() + deadline
    while not cond():
        if time.monotonic() > expires_at:
            _fail("Waiting for condition timed out", comment)
            return
        time.sleep(sleep)
164
165
def ktap_result(ok, cnt=1, case_name="", comment=""):
    """
    Print one TAP result line and fold @ok into KSFT_RESULT_ALL.

    Line format: "[not ]ok <cnt> <KSFT_MAIN_NAME>[.<case_name>][ # <comment>]".
    When _ksft_supports_color() allows it, the line is wrapped in ANSI
    colors: yellow for SKIP/XFAIL comments, otherwise green (ok) or red.
    """
    global KSFT_RESULT_ALL
    KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok

    status = "ok " if ok else "not ok "
    test_id = KSFT_MAIN_NAME
    if case_name:
        test_id += "." + case_name
    line = status + str(cnt) + " " + test_id
    if comment:
        line += " # " + comment

    if _ksft_supports_color():
        if comment.startswith(("SKIP", "XFAIL")):
            color = "\033[33m"  # yellow
        else:
            color = "\033[32m" if ok else "\033[31m"  # green / red
        line = color + line + "\033[0m"

    print(line, flush=True)
189
190
def _ksft_defer_arm(state):
    """Enable (True) or disable (False) defer() by setting utils.GLOBAL_DEFER_ARMED."""
    setattr(utils, "GLOBAL_DEFER_ARMED", state)
194
195
def ksft_flush_defer():
    """
    Run and drain every entry of utils.GLOBAL_DEFER_QUEUE.

    Entries are taken with pop(), i.e. from the end of the queue, until it
    is empty. An Exception from one entry is logged with its traceback,
    marks the current test as failed (KSFT_RESULT = False) and does not stop
    the flush. BaseExceptions such as KeyboardInterrupt are not caught.
    """
    global KSFT_RESULT

    total = len(utils.GLOBAL_DEFER_QUEUE)
    position = 0
    while utils.GLOBAL_DEFER_QUEUE:
        position += 1
        callback = utils.GLOBAL_DEFER_QUEUE.pop()
        try:
            callback.exec_only()
        except Exception:
            header = ("Exception while handling defer / cleanup "
                      f"(callback {position} of {total})!")
            ksft_pr(header)
            ksft_pr(traceback.format_exc(), line_pfx="Defer Exception|")
            KSFT_RESULT = False
210
211
# A test function together with the variants (parameter sets) attached to
# it by @ksft_variants; expanded into one sub-case per variant.
KsftCaseFunction = namedtuple("KsftCaseFunction", "name original_func variants")
214
215
def ksft_disruptive(func):
    """
    Mark a test as disruptive (e.g. one that may bring an interface down).

    When KSFT_DISRUPTIVE is false, which happens when the DISRUPTIVE
    environment variable is set to a false value (see ksft_setup()), the
    wrapped test raises KsftSkipEx instead of running.
    """

    @functools.wraps(func)
    def _guarded(*args, **kwargs):
        if KSFT_DISRUPTIVE:
            return func(*args, **kwargs)
        raise KsftSkipEx("marked as disruptive")

    return _guarded
229
230
class KsftNamedVariant:
    """
    One @ksft_variants input: an argument tuple plus the name of its sub-case.

    With an empty/None @name, the name is built by joining str() of every
    parameter with "_".
    """

    def __init__(self, name, *params):
        self.params = params
        if not name:
            name = "_".join(map(str, params))
        self.name = name
237
238
def ksft_variants(params):
    """
    Decorator attaching a set of inputs (variants) to a test case.

    @params is a list or a generator; each entry is a single object, a
    tuple of arguments, or a KsftNamedVariant. It is stored as-is, so a
    generator is only consumed when the cases are generated. Every entry
    becomes a sub-case whose name carries the parameters (or the explicit
    KsftNamedVariant name). The decorated name is bound to a
    KsftCaseFunction, not to a plain function.

    Example:

    @ksft_variants([
        (1, "a"),
        (2, "b"),
        KsftNamedVariant("three", 3, "c"),
    ])
    def my_case(cfg, a, b):
        pass # ...

    ksft_run(cases=[my_case], args=(cfg, ))

    Will generate cases:
        my_case.1_a
        my_case.2_b
        my_case.three
    """

    def _decorate(func):
        return KsftCaseFunction(func.__name__, func, params)

    return _decorate
265
266
def ksft_setup(env):
    """
    Setup test framework global state from the environment.

    @env: mapping of environment variable names to string values
          (e.g. os.environ).

    Only DISRUPTIVE is consumed today. It accepts yes/true/no/false
    (case-insensitive, surrounding whitespace ignored) or an integer, and
    sets KSFT_DISRUPTIVE accordingly. If it is absent, KSFT_DISRUPTIVE is
    left unchanged.

    Returns @env itself.
    Raises Exception("failed to parse DISRUPTIVE") for any other value.
    """
    global KSFT_DISRUPTIVE

    def get_bool(env, name):
        value = env.get(name, "").strip().lower()
        if value in ("yes", "true"):
            return True
        if value in ("no", "false"):
            return False
        try:
            return bool(int(value))
        except ValueError as err:
            # Keep the original parse error attached for debugging
            raise Exception(f"failed to parse {name}") from err

    if "DISRUPTIVE" in env:
        KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE")

    return env
288
289
def _ksft_intr(signum, frame):
    """
    SIGTERM handler installed by ksft_run().

    The first SIGTERM raises KsftTerminate to stop the run; any later one
    is only logged. The ksft runner.sh sends two SIGTERMs in a row on a
    timeout, and acting on the second one would cut the cleanup short.
    """
    global term_cnt
    term_cnt += 1
    if term_cnt != 1:
        ksft_pr(f"Ignoring SIGTERM (cnt: {term_cnt}), already exiting...")
        return
    raise KsftTerminate()
299
300
def _ksft_generate_test_cases(cases, globs, case_pfx, args):
    """
    Generate a flat list of (func, args, name) tuples.

    @cases:    explicit test functions / @ksft_variants cases (not modified)
    @globs:    optional namespace dict (e.g. globals()) searched for cases
    @case_pfx: a name prefix, or an iterable of prefixes, selecting the
               entries of @globs to run
    @args:     positional arguments for every case; variant parameters
               are appended after them

    A plain function yields one entry named after the function. A
    KsftCaseFunction yields one entry per variant, named "<func>.<variant>".
    """

    # Work on a copy so the caller's list is never modified
    cases = list(cases or [])
    test_cases = []

    # If using the globs method find all relevant functions
    if globs and case_pfx:
        # A bare string is a single prefix, not a sequence of 1-char ones
        if isinstance(case_pfx, str):
            case_pfx = (case_pfx, )
        prefixes = tuple(case_pfx)
        for key, value in globs.items():
            # @ksft_variants produces a KsftCaseFunction, which is not
            # callable but is a test case all the same
            if not (callable(value) or isinstance(value, KsftCaseFunction)):
                continue
            if key.startswith(prefixes):
                cases.append(value)

    for func in cases:
        if isinstance(func, KsftCaseFunction):
            # Parametrized test - create case for each param
            for param in func.variants:
                if not isinstance(param, KsftNamedVariant):
                    if not isinstance(param, tuple):
                        param = (param, )
                    param = KsftNamedVariant(None, *param)

                test_cases.append((func.original_func,
                                   (*args, *param.params),
                                   func.name + "." + param.name))
        else:
            test_cases.append((func, args, func.__name__))

    return test_cases
333
334
def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
    """
    Run test cases and print the results in TAP format.

    @cases, @globs and @case_pfx select the test cases; @args is passed to
    each of them (see _ksft_generate_test_cases()).

    Each case ends as pass, fail (exception or failed ksft_* check), skip
    (KsftSkipEx) or xfail (KsftXfailEx). Callbacks queued with defer() are
    flushed after every case. A KeyboardInterrupt, including KsftTerminate
    raised on the first SIGTERM, stops the run once the current case has
    been cleaned up and reported. The previous SIGTERM handler is restored
    on the way out, even if an exception escapes.
    """
    test_cases = _ksft_generate_test_cases(cases, globs, case_pfx, args)

    global KSFT_RESULT
    global term_cnt
    term_cnt = 0
    prev_sigterm = signal.signal(signal.SIGTERM, _ksft_intr)

    totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}

    try:
        print("TAP version 13", flush=True)
        print("1.." + str(len(test_cases)), flush=True)

        stop = False
        for cnt, (func, case_args, name) in enumerate(test_cases, start=1):
            KSFT_RESULT = True
            comment = ""
            cnt_key = ""

            _ksft_defer_arm(True)
            try:
                func(*case_args)
            except KsftSkipEx as e:
                comment = "SKIP " + str(e)
                cnt_key = 'skip'
            except KsftXfailEx as e:
                comment = "XFAIL " + str(e)
                cnt_key = 'xfail'
            except BaseException as e:
                stop |= isinstance(e, KeyboardInterrupt)
                ksft_pr(traceback.format_exc(), line_pfx="Exception|")
                if stop:
                    ksft_pr(f"Stopping tests due to {type(e).__name__}.")
                KSFT_RESULT = False
                cnt_key = 'fail'
            _ksft_defer_arm(False)

            try:
                ksft_flush_defer()
            except BaseException as e:
                ksft_pr(traceback.format_exc(), line_pfx="Exception|")
                if isinstance(e, KeyboardInterrupt):
                    ksft_pr()
                    ksft_pr("WARN: defer() interrupted, cleanup may be incomplete.")
                    ksft_pr("      Attempting to finish cleanup before exiting.")
                    ksft_pr("      Interrupt again to exit immediately.")
                    ksft_pr()
                    stop = True
                # Flush was interrupted, try to finish the job best we can
                ksft_flush_defer()

            # No explicit outcome from an exception: decide by the checks
            if not cnt_key:
                cnt_key = 'pass' if KSFT_RESULT else 'fail'

            ktap_result(KSFT_RESULT, cnt, name, comment=comment)
            totals[cnt_key] += 1

            if stop:
                break
    finally:
        signal.signal(signal.SIGTERM, prev_sigterm)

    print(
        f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0",
        flush=True,
    )
402
403
def ksft_exit():
    """
    Exit the process with status 0 if KSFT_RESULT_ALL is still true, i.e.
    ktap_result() never reported a "not ok" line, and with status 1 otherwise.
    """
    status = 1
    if KSFT_RESULT_ALL:
        status = 0
    sys.exit(status)
407