xref: /linux/tools/testing/selftests/net/lib/py/ksft.py (revision b324192e36ecea472077f9c9e32bcac8bbafeed6)
1# SPDX-License-Identifier: GPL-2.0
2
3import functools
4import inspect
5import signal
6import sys
7import time
8import traceback
9from collections import namedtuple
10from .consts import KSFT_MAIN_NAME
11from . import utils
12
# Outcome of the test case that is currently running. ksft_run() resets it
# to True before every case; _fail() and ksft_flush_defer() clear it.
KSFT_RESULT = None
# Running AND of every result reported through ktap_result(); ksft_exit()
# converts it into the process exit status.
KSFT_RESULT_ALL = True
# Whether tests wrapped in @ksft_disruptive may run. ksft_setup() overrides
# it from the DISRUPTIVE environment variable.
KSFT_DISRUPTIVE = True
16
17
class KsftFailEx(Exception):
    """
    Test failure exception.

    ksft_run() has no dedicated handler for it: it is reported like any
    other unexpected exception (traceback logged, case marked as failed).
    """
20
21
class KsftSkipEx(Exception):
    """Raised by a test case to report itself as skipped ("SKIP <message>")."""
24
25
class KsftXfailEx(Exception):
    """Raised by a test case to report an expected failure ("XFAIL <message>")."""
28
29
class KsftTerminate(KeyboardInterrupt):
    """
    Raised from the SIGTERM handler installed by ksft_run().

    Deriving from KeyboardInterrupt makes ksft_run() treat a SIGTERM like
    Ctrl-C: the current case fails and no further cases are started.
    """
32
33
def ksft_pr(*objs, **kwargs):
    """
    Log to stdout without breaking the TAP output.

    Accepts the same arguments as print(), but every emitted line starts
    with "#" so that it is not mistaken for a TAP result line, and output
    is always flushed.

    Extra keyword argument (in addition to print()'s):
      line_pfx - string placed after the "#" on every line
    """
    separator = kwargs.pop("sep", " ")
    extra = kwargs.pop("line_pfx", "")
    head = "# " + extra if extra else "#"
    kwargs["flush"] = True

    lines = separator.join(map(str, objs)).split('\n')
    # print() inserts one space between head and the first line; every
    # following line gets the head re-attached explicitly.
    print(head, f"\n{head} ".join(lines), **kwargs)
52
53
def _fail(*args):
    """
    Record a failed check in the current test case and log where it happened.

    Clears KSFT_RESULT, prints the call chain from the test case function
    (the frame just below ksft_run) down to the check helper's caller, then
    prints @args. Execution of the test case continues afterwards.
    """
    global KSFT_RESULT
    KSFT_RESULT = False

    started = False
    # stack[0] is _fail itself and stack[1] the helper that called it
    # (ksft_eq, ksft_raises.__exit__, ...); walk the rest outermost-first.
    for frame in reversed(inspect.stack()[2:]):
        # Start printing from the test case function
        if not started:
            if frame.function == 'ksft_run':
                started = True
            continue

        ksft_pr(f"Check| At {frame.filename}, line {frame.lineno}, in {frame.function}:")
        # code_context is None when the source is unavailable (e.g. exec'd
        # code); don't let the failure reporter itself crash on that.
        if frame.code_context:
            ksft_pr("Check|     " + frame.code_context[0].strip())
    ksft_pr(*args)
71
72
# Soft assertions: on failure each helper logs the values via _fail() (which
# marks the current case as failed) and returns normally, so the test case
# keeps running.

def ksft_eq(a, b, comment=""):
    """Check that a == b."""
    if a != b:
        _fail("Check failed", a, "!=", b, comment)


def ksft_ne(a, b, comment=""):
    """Check that a != b."""
    if a == b:
        _fail("Check failed", a, "==", b, comment)


def ksft_true(a, comment=""):
    """Check that a is truthy."""
    if not a:
        _fail("Check failed", a, "does not eval to True", comment)


def ksft_not_none(a, comment=""):
    """Check that a is not None."""
    if a is None:
        _fail("Check failed", a, "is None", comment)


def ksft_in(a, b, comment=""):
    """Check that a is contained in b."""
    if a not in b:
        _fail("Check failed", a, "not in", b, comment)


def ksft_not_in(a, b, comment=""):
    """Check that a is not contained in b."""
    if a in b:
        _fail("Check failed", a, "in", b, comment)


def ksft_is(a, b, comment=""):
    """Check that a and b are the same object."""
    if a is not b:
        _fail("Check failed", a, "is not", b, comment)


def ksft_ge(a, b, comment=""):
    """Check a >= b (fails only when a < b is true)."""
    if a < b:
        _fail("Check failed", a, "<", b, comment)


def ksft_gt(a, b, comment=""):
    """Check a > b (fails only when a <= b is true)."""
    if a <= b:
        _fail("Check failed", a, "<=", b, comment)


def ksft_lt(a, b, comment=""):
    """Check a < b (fails only when a >= b is true)."""
    if a >= b:
        _fail("Check failed", a, ">=", b, comment)
123
124
class ksft_raises:
    """
    Context manager checking that its body raises @expected_type.

    An exception of @expected_type (or a subclass of it) is suppressed and
    stored in .exception. If nothing is raised, or an exception of another
    type is raised, the check fails via _fail(); an unexpected exception is
    additionally left to propagate.

    Example:
        with ksft_raises(KeyError) as ctx:
            {}["missing"]
        ksft_pr(ctx.exception)
    """

    def __init__(self, expected_type):
        self.exception = None
        self.expected_type = expected_type

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Accept subclasses, like unittest's assertRaises() / pytest.raises()
        matched = exc_type is not None and issubclass(exc_type, self.expected_type)
        if exc_type is None:
            _fail(f"Expected exception {self.expected_type.__name__}, none raised")
        elif not matched:
            _fail(f"Expected exception {self.expected_type.__name__}, raised {exc_type.__name__}")
        self.exception = exc_val
        # A true return value suppresses the exception - only the expected one
        return matched
141
142
def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
    """
    Poll cond() until it returns a truthy value or @deadline seconds pass.

    cond() is evaluated before every deadline check, so it is called at
    least once. On timeout the check fails via _fail() and the function
    returns; it never raises on timeout. @sleep is the delay in seconds
    between polls.
    """
    expires_at = time.monotonic() + deadline
    while not cond():
        if time.monotonic() > expires_at:
            _fail("Waiting for condition timed out", comment)
            return
        time.sleep(sleep)
152
153
def ktap_result(ok, cnt=1, case_name="", comment=""):
    """
    Print one KTAP result line and fold @ok into KSFT_RESULT_ALL.

    Line format: "[not ]ok <cnt> <KSFT_MAIN_NAME>[.<case_name>][ # <comment>]"
    """
    global KSFT_RESULT_ALL
    KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok

    parts = ["ok " if ok else "not ok ", str(cnt), " ", KSFT_MAIN_NAME]
    if case_name:
        parts += [".", case_name]
    if comment:
        parts += [" # ", comment]
    print("".join(parts), flush=True)
169
170
def _ksft_defer_arm(state):
    """
    Permit (True) or forbid (False) use of defer().

    Only stores @state in utils.GLOBAL_DEFER_ARMED; ksft_run() arms it
    while a test case body runs and disarms it before flushing defers.
    """
    utils.GLOBAL_DEFER_ARMED = state
174
175
def ksft_flush_defer():
    """
    Run and drain every queued defer() callback.

    Entries are popped from the end of utils.GLOBAL_DEFER_QUEUE, i.e. the
    most recently queued one runs first. A callback raising Exception is
    logged and marks the current case failed, and draining continues.
    BaseException-only errors (e.g. KeyboardInterrupt) propagate to the caller.
    """
    global KSFT_RESULT

    total = len(utils.GLOBAL_DEFER_QUEUE)
    done = 0
    while utils.GLOBAL_DEFER_QUEUE:
        done += 1
        # Pop before running so a failing entry is not retried
        item = utils.GLOBAL_DEFER_QUEUE.pop()
        try:
            item.exec_only()
        except Exception:
            ksft_pr(f"Exception while handling defer / cleanup (callback {done} of {total})!")
            ksft_pr(traceback.format_exc(), line_pfx="Defer Exception|")
            KSFT_RESULT = False
190
191
# A test case produced by @ksft_variants: its display name, the undecorated
# function, and the iterable of variant parameters.
KsftCaseFunction = namedtuple("KsftCaseFunction",
                              "name original_func variants")
194
195
def ksft_disruptive(func):
    """
    Decorator marking a test as disruptive (e.g. it may bring the
    interface down).

    KSFT_DISRUPTIVE is checked on every call: when it is false the test
    raises KsftSkipEx instead of running. Set DISRUPTIVE=False in the
    environment (parsed by ksft_setup()) to skip such tests.
    """

    @functools.wraps(func)
    def _guarded(*args, **kwargs):
        if KSFT_DISRUPTIVE:
            return func(*args, **kwargs)
        raise KsftSkipEx("marked as disruptive")

    return _guarded
209
210
class KsftNamedVariant:
    """
    One @ksft_variants entry: an argument tuple plus a sub-case name.

    When @name is falsy, the name is built by joining str() of each
    parameter with "_".
    """

    def __init__(self, name, *params):
        self.params = params
        if name:
            self.name = name
        else:
            self.name = "_".join(map(str, params))
217
218
def ksft_variants(params):
    """
    Decorator attaching a set of input variants to a test case.

    Every entry of @params becomes its own sub-case, named after its
    parameters. An entry may be a single object, a tuple of arguments, or
    a KsftNamedVariant carrying an explicit name. @params is stored as-is
    and only iterated when ksft_run() generates the cases, so it may be a
    list or a generator (a generator is consumed by that first iteration).

    Example:

    @ksft_variants([
        (1, "a"),
        (2, "b"),
        KsftNamedVariant("three", 3, "c"),
    ])
    def my_case(cfg, a, b):
        pass # ...

    ksft_run(cases=[my_case], args=(cfg, ))

    Will generate cases:
        my_case.1_a
        my_case.2_b
        my_case.three
    """

    def _attach(func):
        return KsftCaseFunction(func.__name__, func, params)

    return _attach
245
246
def ksft_setup(env):
    """
    Set up the framework's global state from the environment mapping @env.

    Currently only DISRUPTIVE is consumed: "yes"/"true" and "no"/"false"
    (case-insensitive) or an integer (non-zero means true). Any other value
    raises Exception("failed to parse DISRUPTIVE"), chained to the
    underlying ValueError. Returns @env unchanged.
    """
    global KSFT_DISRUPTIVE

    def get_bool(env, name):
        value = env.get(name, "").lower()
        if value in ("yes", "true"):
            return True
        if value in ("no", "false"):
            return False
        try:
            return bool(int(value))
        except ValueError as e:
            # Keep the original parse error as the cause for debugging
            raise Exception(f"failed to parse {name}") from e

    if "DISRUPTIVE" in env:
        KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE")

    return env
268
269
def _ksft_intr(signum, frame):
    """
    SIGTERM handler installed by ksft_run().

    The first SIGTERM raises KsftTerminate so that the current case fails
    and the run stops. The ksft runner.sh sends two SIGTERMs in a row on a
    timeout, so every later SIGTERM is only logged and ignored; otherwise
    it would interrupt the cleanup the first one triggered.
    """
    global term_cnt
    term_cnt += 1
    if term_cnt != 1:
        ksft_pr(f"Ignoring SIGTERM (cnt: {term_cnt}), already exiting...")
        return
    raise KsftTerminate()
279
280
def _ksft_generate_test_cases(cases, globs, case_pfx, args):
    """
    Generate a flat list of (func, args, name) tuples.

    @cases is an iterable of test functions and/or KsftCaseFunction objects
    (produced by @ksft_variants). When both @globs and @case_pfx are given,
    every entry of @globs whose name starts with one of the prefixes in
    @case_pfx (a single string or an iterable of strings) is added as well.
    Variant cases expand into one entry per variant, named
    "<func>.<variant>", with the variant's parameters appended to @args.
    """
    # Work on a copy: appending glob matches must not modify the caller's
    # list (and must also work when a tuple was passed).
    cases = list(cases or [])
    test_cases = []

    # If using the globs method find all relevant functions
    if globs and case_pfx:
        # A bare string would otherwise be iterated character by character
        prefixes = (case_pfx, ) if isinstance(case_pfx, str) else tuple(case_pfx)
        for key, value in globs.items():
            # @ksft_variants replaces the function with a KsftCaseFunction
            # namedtuple, which is not callable but is still a test case
            if not (callable(value) or isinstance(value, KsftCaseFunction)):
                continue
            if key.startswith(prefixes):
                cases.append(value)

    for func in cases:
        if isinstance(func, KsftCaseFunction):
            # Parametrized test - create case for each param
            for param in func.variants:
                if not isinstance(param, KsftNamedVariant):
                    if not isinstance(param, tuple):
                        param = (param, )
                    param = KsftNamedVariant(None, *param)

                test_cases.append((func.original_func,
                                   (*args, *param.params),
                                   func.name + "." + param.name))
        else:
            test_cases.append((func, args, func.__name__))

    return test_cases
313
314
def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
    """
    Run test cases and print the results in TAP format.

    Cases come from @cases and/or from entries of @globs whose names start
    with @case_pfx (see _ksft_generate_test_cases()). Each case is called
    with @args, plus its variant parameters for @ksft_variants cases.
    KsftSkipEx / KsftXfailEx mark a case SKIP / XFAIL; any other exception
    fails it. The defer() queue is flushed after every case. A
    KeyboardInterrupt (including KsftTerminate, raised by the SIGTERM
    handler installed here) stops the run after the current case.
    """
    global term_cnt
    global KSFT_RESULT

    test_cases = _ksft_generate_test_cases(cases, globs, case_pfx, args)

    term_cnt = 0
    prev_sigterm = signal.signal(signal.SIGTERM, _ksft_intr)

    totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}

    print("TAP version 13", flush=True)
    print("1.." + str(len(test_cases)), flush=True)

    cnt = 0
    stop = False
    # case_args rather than args: don't shadow the function parameter
    for func, case_args, name in test_cases:
        KSFT_RESULT = True
        cnt += 1
        comment = ""
        cnt_key = ""

        _ksft_defer_arm(True)
        try:
            func(*case_args)
        except KsftSkipEx as e:
            comment = "SKIP " + str(e)
            cnt_key = 'skip'
        except KsftXfailEx as e:
            comment = "XFAIL " + str(e)
            cnt_key = 'xfail'
        except BaseException as e:
            # Ctrl-C / SIGTERM: fail this case and don't start any more
            stop |= isinstance(e, KeyboardInterrupt)
            ksft_pr(traceback.format_exc(), line_pfx="Exception|")
            if stop:
                ksft_pr(f"Stopping tests due to {type(e).__name__}.")
            KSFT_RESULT = False
            cnt_key = 'fail'
        _ksft_defer_arm(False)

        try:
            ksft_flush_defer()
        except BaseException as e:
            ksft_pr(traceback.format_exc(), line_pfx="Exception|")
            if isinstance(e, KeyboardInterrupt):
                ksft_pr()
                ksft_pr("WARN: defer() interrupted, cleanup may be incomplete.")
                ksft_pr("      Attempting to finish cleanup before exiting.")
                ksft_pr("      Interrupt again to exit immediately.")
                ksft_pr()
                stop = True
            # Flush was interrupted, try to finish the job best we can.
            # Deliberately unguarded: a second interrupt escapes from here.
            ksft_flush_defer()

        # Cases without an explicit outcome pass unless a check failed
        if not cnt_key:
            cnt_key = 'pass' if KSFT_RESULT else 'fail'

        ktap_result(KSFT_RESULT, cnt, name, comment=comment)
        totals[cnt_key] += 1

        if stop:
            break

    signal.signal(signal.SIGTERM, prev_sigterm)

    print(
        f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0",
        flush=True,
    )
382
383
def ksft_exit():
    """
    Exit the process with the aggregate test status.

    Exits with 0 when KSFT_RESULT_ALL is truthy (no ktap_result() call
    reported a failure), otherwise with 1.
    """
    sys.exit(int(not KSFT_RESULT_ALL))
387