xref: /linux/tools/testing/selftests/net/lib/py/ksft.py (revision c4277d21ab694c7964a48759a5452e5bbbe12965)
1# SPDX-License-Identifier: GPL-2.0
2
3import functools
4import inspect
5import signal
6import sys
7import time
8import traceback
9from collections import namedtuple
10from .consts import KSFT_MAIN_NAME
11from . import utils
12
# Result of the test case currently being run: ksft_run() resets it to True
# before every case; _fail(), an exception in the case, or a failing deferred
# cleanup set it to False.
KSFT_RESULT = None
# Aggregate result of the whole run: ktap_result() ANDs every reported result
# into it and ksft_exit() turns it into the process exit status.
KSFT_RESULT_ALL = True
# Whether cases wrapped with @ksft_disruptive are allowed to run; ksft_setup()
# overrides it from the DISRUPTIVE environment variable.
KSFT_DISRUPTIVE = True
16
17
class KsftFailEx(Exception):
    """
    Exception type for signalling a test failure.

    ksft_run() has no dedicated handler for it; like any other exception
    escaping a test case it is printed and the case is counted as failed.
    """
20
21
class KsftSkipEx(Exception):
    """
    Raised by a test case to request that it be skipped.

    ksft_run() reports the case with a "SKIP <message>" comment and counts
    it in the "skip" total.
    """
24
25
class KsftXfailEx(Exception):
    """
    Raised by a test case to report an expected failure.

    ksft_run() reports the case with an "XFAIL <message>" comment and counts
    it in the "xfail" total.
    """
28
29
class KsftTerminate(KeyboardInterrupt):
    """
    Raised by the SIGTERM handler (_ksft_intr) on the first SIGTERM.

    Deriving from KeyboardInterrupt makes ksft_run() treat it like Ctrl-C:
    the current case fails and no further cases are started.
    """
32
33
def ksft_pr(*objs, **kwargs):
    """
    Print a diagnostic line, i.e. the objects prefixed with "#".

    Accepts the same keyword arguments as print(); "flush" is always
    forced to True regardless of what the caller passed.
    """
    print_opts = dict(kwargs)
    print_opts["flush"] = True
    print("#", *objs, **print_opts)
37
38
def _fail(*args):
    """
    Mark the current test case as failed and print where the check failed.

    Sets KSFT_RESULT to False, prints the call chain leading to the failed
    check and finally prints @args as the failure message.

    Must be called directly from the check helper: this function's frame
    and the frame of its immediate caller are left out of the output.
    """
    global KSFT_RESULT
    KSFT_RESULT = False

    stack = inspect.stack()
    started = False
    # stack[0] is _fail() and stack[1] the check helper, skip both and
    # walk from the outermost frame towards the failed check
    for frame in reversed(stack[2:]):
        # Start printing from the test case function
        if not started:
            if frame.function == 'ksft_run':
                started = True
            continue

        ksft_pr("Check| At " + frame.filename + ", line " + str(frame.lineno) +
                ", in " + frame.function + ":")
        # code_context is None when the source of the frame is not available
        if frame.code_context:
            ksft_pr("Check|     " + frame.code_context[0].strip())
    ksft_pr(*args)
56
57
def ksft_eq(a, b, comment=""):
    """
    Check that @a equals @b.

    If `a != b` the current test case is marked as failed and both values
    are printed together with @comment. Execution of the case continues.
    """
    if a != b:
        _fail("Check failed", a, "!=", b, comment)
62
63
def ksft_ne(a, b, comment=""):
    """
    Check that @a differs from @b.

    If `a == b` the current test case is marked as failed and both values
    are printed together with @comment. Execution of the case continues.
    """
    if a == b:
        _fail("Check failed", a, "==", b, comment)
68
69
def ksft_true(a, comment=""):
    """ Check that @a is truthy, otherwise mark the test case as failed """
    if a:
        return
    # Call _fail() directly, it skips exactly one caller frame
    _fail("Check failed", a, "does not eval to True", comment)
73
74
def ksft_not_none(a, comment=""):
    """ Check that @a is not None, otherwise mark the test case as failed """
    if a is not None:
        return
    # Call _fail() directly, it skips exactly one caller frame
    _fail("Check failed", a, "is None", comment)
78
79
def ksft_in(a, b, comment=""):
    """ Check that @a is contained in @b, otherwise mark the test case as failed """
    if a in b:
        return
    # Call _fail() directly, it skips exactly one caller frame
    _fail("Check failed", a, "not in", b, comment)
83
84
def ksft_not_in(a, b, comment=""):
    """ Check that @a is not contained in @b, otherwise mark the test case as failed """
    if a not in b:
        return
    # Call _fail() directly, it skips exactly one caller frame
    _fail("Check failed", a, "in", b, comment)
88
89
def ksft_is(a, b, comment=""):
    """ Check that @a and @b are the same object, otherwise mark the test case as failed """
    if a is b:
        return
    # Call _fail() directly, it skips exactly one caller frame
    _fail("Check failed", a, "is not", b, comment)
93
94
def ksft_ge(a, b, comment=""):
    """ Check for a >= b: the test case is marked as failed when `a < b` holds """
    if not a < b:
        return
    # Call _fail() directly, it skips exactly one caller frame
    _fail("Check failed", a, "<", b, comment)
98
99
def ksft_gt(a, b, comment=""):
    """ Check for a > b: the test case is marked as failed when `a <= b` holds """
    if not a <= b:
        return
    # Call _fail() directly, it skips exactly one caller frame
    _fail("Check failed", a, "<=", b, comment)
103
104
def ksft_lt(a, b, comment=""):
    """ Check for a < b: the test case is marked as failed when `a >= b` holds """
    if not a >= b:
        return
    # Call _fail() directly, it skips exactly one caller frame
    _fail("Check failed", a, ">=", b, comment)
108
109
class ksft_raises:
    """
    Context manager checking that its body raises the expected exception.

    The check passes, and the exception is suppressed, when the body raises
    @expected_type or one of its subclasses. If nothing is raised, or an
    unrelated exception is raised, the test case is marked as failed; an
    unrelated exception is additionally left to propagate.

    The exception instance (or None) is available as the `exception`
    attribute once the block has been left:

    with ksft_raises(ValueError) as cm:
        int("x")
    ksft_in("invalid literal", str(cm.exception))
    """

    def __init__(self, expected_type):
        self.exception = None
        self.expected_type = expected_type

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.exception = exc_val
        if exc_type is None:
            _fail(f"Expected exception {self.expected_type.__name__}, none raised")
            return False
        # Accept subclasses as well, the same way an except clause would
        if issubclass(exc_type, self.expected_type):
            # Suppress the exception if its the expected one
            return True
        _fail(f"Expected exception {self.expected_type.__name__}, raised {exc_type.__name__}")
        return False
126
127
def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
    """
    Poll @cond() until it returns a truthy value.

    @sleep is the pause between two polls and @deadline the total time to
    wait, both as used by the time module (seconds). Once the deadline has
    passed and the condition is still not met the test case is marked as
    failed and the function returns.
    """
    timeout_at = time.monotonic() + deadline
    # The condition is evaluated before the deadline test, so it always
    # gets one more chance after the last sleep
    while not cond():
        if time.monotonic() > timeout_at:
            _fail("Waiting for condition timed out", comment)
            return
        time.sleep(sleep)
136        time.sleep(sleep)
137
138
def ktap_result(ok, cnt=1, case_name="", comment=""):
    """
    Print the result line of one test case and record it in the overall result.

    The line has the form
        [not ]ok <cnt> <KSFT_MAIN_NAME>[.<case_name>][ # <comment>]
    and @ok is ANDed into KSFT_RESULT_ALL.
    """
    global KSFT_RESULT_ALL
    KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok

    status = "ok" if ok else "not ok"
    line = " ".join([status, str(cnt), KSFT_MAIN_NAME])
    if case_name:
        line += "." + case_name
    if comment:
        line += " # " + comment
    print(line, flush=True)
154
155
def _ksft_defer_arm(state):
    """
    Set utils.GLOBAL_DEFER_ARMED to @state.

    ksft_run() passes True right before running a test case and False
    right after, i.e. defer() is meant to be usable only within a case.
    """
    utils.GLOBAL_DEFER_ARMED = state
159
160
def ksft_flush_defer():
    """
    Run all pending deferred callbacks until utils.GLOBAL_DEFER_QUEUE is empty.

    Entries are removed with pop() one at a time and executed via their
    exec_only() method. An Exception raised by a callback is printed and
    marks the test case as failed, but does not stop the remaining
    callbacks from running. Other BaseExceptions (e.g. KeyboardInterrupt)
    propagate to the caller.
    """
    global KSFT_RESULT

    total = len(utils.GLOBAL_DEFER_QUEUE)
    done = 0
    while utils.GLOBAL_DEFER_QUEUE:
        done += 1
        callback = utils.GLOBAL_DEFER_QUEUE.pop()
        try:
            callback.exec_only()
        except Exception:
            ksft_pr(f"Exception while handling defer / cleanup (callback {done} of {total})!")
            for tb_line in traceback.format_exc().strip().split('\n'):
                ksft_pr("Defer Exception|", tb_line)
            KSFT_RESULT = False
177
178
# Test case with variants, as produced by @ksft_variants:
#   name          - name of the decorated function
#   original_func - the decorated function itself
#   variants      - the inputs handed to the decorator
KsftCaseFunction = namedtuple("KsftCaseFunction", "name original_func variants")
181
182
def ksft_disruptive(func):
    """
    Decorator for disruptive test cases (e.g. tests which can bring
    the interface down).

    When disruptive tests are disabled (DISRUPTIVE=False in the
    environment, see ksft_setup()) calling the decorated test raises
    KsftSkipEx instead of running it.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # The flag is read at call time, not at decoration time
        if KSFT_DISRUPTIVE:
            return func(*args, **kwargs)
        raise KsftSkipEx("marked as disruptive")

    return wrapper
196
197
class KsftNamedVariant:
    """
    One input set for @ksft_variants with an explicit sub-case name.

    @params are stored as a tuple. If @name is empty or None the name is
    built by joining str() of each parameter with "_".
    """

    def __init__(self, name, *params):
        self.params = params
        if name:
            self.name = name
        else:
            self.name = "_".join(map(str, params))
204
205
def ksft_variants(params):
    """
    Decorator which turns one test function into a set of sub-cases,
    one for each entry of @params.

    An entry can be a single object, a tuple of arguments or
    a KsftNamedVariant; @params itself can be a list or a generator.
    The entries end up in the name of the generated sub-case.

    Example:

    @ksft_variants([
        (1, "a"),
        (2, "b"),
        KsftNamedVariant("three", 3, "c"),
    ])
    def my_case(cfg, a, b):
        pass # ...

    ksft_run(cases=[my_case], args=(cfg, ))

    Will generate cases:
        my_case.1_a
        my_case.2_b
        my_case.three
    """

    def decorator(func):
        # Only record the variants here, they get expanded by ksft_run()
        return KsftCaseFunction(func.__name__, func, params)

    return decorator
232
233
def ksft_setup(env):
    """
    Setup test framework global state from the environment.

    @env is a dict-like object of environment variables. If it contains
    DISRUPTIVE the value is parsed as a boolean ("yes" / "true" / "no" /
    "false" in any letter case, or an integer) and stored in
    KSFT_DISRUPTIVE. Raises an Exception if the value can't be parsed.

    Returns @env unchanged.
    """

    def get_bool(env, name):
        value = env.get(name, "").lower()
        if value in ["yes", "true"]:
            return True
        if value in ["no", "false"]:
            return False
        try:
            return bool(int(value))
        except ValueError as err:
            # Keep the original error attached as the explicit cause
            raise Exception(f"failed to parse {name}") from err

    if "DISRUPTIVE" in env:
        global KSFT_DISRUPTIVE
        KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE")

    return env
255
256
def _ksft_intr(signum, frame):
    """
    SIGTERM handler installed by ksft_run().

    The first signal raises KsftTerminate in the interrupted code,
    any further one is only logged.
    """
    # ksft runner.sh sends 2 SIGTERMs in a row on a timeout
    # if we don't ignore the second one it will stop us from handling cleanup
    global term_cnt
    term_cnt += 1
    if term_cnt != 1:
        ksft_pr(f"Ignoring SIGTERM (cnt: {term_cnt}), already exiting...")
        return
    raise KsftTerminate()
266
267
def _ksft_generate_test_cases(cases, globs, case_pfx, args):
    """
    Generate a flat list of (func, args, name) tuples

    @cases is an optional sequence of test functions and / or
    KsftCaseFunction objects; it is not modified. If both @globs (a dict
    of names to objects, e.g. globals()) and @case_pfx (name prefixes) are
    given, every callable in @globs whose name starts with one of the
    prefixes is added after the explicitly listed cases.

    Plain functions produce a single entry called with @args. For
    a KsftCaseFunction one entry per variant is produced, called with
    @args followed by the parameters of the variant and named
    "<function name>.<variant name>".
    """

    # Work on a copy, the caller's sequence must not grow as a side effect
    cases = list(cases) if cases else []
    test_cases = []

    # If using the globs method find all relevant functions
    if globs and case_pfx:
        prefixes = tuple(case_pfx)
        for key, value in globs.items():
            if callable(value) and key.startswith(prefixes):
                cases.append(value)

    for func in cases:
        if isinstance(func, KsftCaseFunction):
            # Parametrized test - create case for each param
            for param in func.variants:
                if not isinstance(param, KsftNamedVariant):
                    if not isinstance(param, tuple):
                        param = (param, )
                    param = KsftNamedVariant(None, *param)

                test_cases.append((func.original_func,
                                   (*args, *param.params),
                                   func.name + "." + param.name))
        else:
            test_cases.append((func, args, func.__name__))

    return test_cases
300
301
def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
    """
    Run the test cases and print the results in TAP format.

    The cases are given either explicitly via @cases (functions and / or
    objects created by @ksft_variants), or by passing @globs together
    with @case_pfx to pick up every callable whose name starts with one
    of the prefixes. Each case is called with @args (plus the variant
    parameters, if any).

    For every case: defer() is armed, the case is run, the deferred
    cleanups are flushed and one result line is printed. KsftSkipEx and
    KsftXfailEx are reported as SKIP / XFAIL, any other exception fails
    the case. A KeyboardInterrupt (including KsftTerminate raised on
    SIGTERM) additionally stops the run after the current case has been
    cleaned up and reported.

    SIGTERM is redirected to _ksft_intr() for the duration of the run.
    """
    global term_cnt
    global KSFT_RESULT

    test_cases = _ksft_generate_test_cases(cases, globs, case_pfx, args)

    term_cnt = 0
    prev_sigterm = signal.signal(signal.SIGTERM, _ksft_intr)

    totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}

    try:
        print("TAP version 13", flush=True)
        print("1.." + str(len(test_cases)), flush=True)

        cnt = 0
        stop = False
        for func, case_args, name in test_cases:
            KSFT_RESULT = True
            cnt += 1
            comment = ""
            cnt_key = ""

            _ksft_defer_arm(True)
            try:
                func(*case_args)
            except KsftSkipEx as e:
                comment = "SKIP " + str(e)
                cnt_key = 'skip'
            except KsftXfailEx as e:
                comment = "XFAIL " + str(e)
                cnt_key = 'xfail'
            except BaseException as e:
                stop |= isinstance(e, KeyboardInterrupt)
                tb = traceback.format_exc()
                for line in tb.strip().split('\n'):
                    ksft_pr("Exception|", line)
                if stop:
                    ksft_pr(f"Stopping tests due to {type(e).__name__}.")
                KSFT_RESULT = False
                cnt_key = 'fail'
            _ksft_defer_arm(False)

            try:
                ksft_flush_defer()
            except BaseException as e:
                tb = traceback.format_exc()
                for line in tb.strip().split('\n'):
                    ksft_pr("Exception|", line)
                if isinstance(e, KeyboardInterrupt):
                    ksft_pr()
                    ksft_pr("WARN: defer() interrupted, cleanup may be incomplete.")
                    ksft_pr("      Attempting to finish cleanup before exiting.")
                    ksft_pr("      Interrupt again to exit immediately.")
                    ksft_pr()
                    stop = True
                # Flush was interrupted, try to finish the job best we can
                ksft_flush_defer()

            if not cnt_key:
                cnt_key = 'pass' if KSFT_RESULT else 'fail'

            ktap_result(KSFT_RESULT, cnt, name, comment=comment)
            totals[cnt_key] += 1

            if stop:
                break
    finally:
        # The retried flush above may get interrupted again and propagate,
        # don't leave our SIGTERM handler installed in that case
        signal.signal(signal.SIGTERM, prev_sigterm)

    print(
        f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0",
        flush=True
    )
373
374
def ksft_exit():
    """
    Exit the process with the overall test result as the status:
    0 if every result reported via ktap_result() was ok, 1 otherwise.
    """
    exit_code = 0 if KSFT_RESULT_ALL else 1
    sys.exit(exit_code)
378