xref: /linux/tools/testing/selftests/net/lib/py/ksft.py (revision e46ff213f7a5f5aaebd6bca589517844aa0fe73a)
1# SPDX-License-Identifier: GPL-2.0
2
3import fnmatch
4import functools
5import getopt
6import inspect
7import os
8import signal
9import sys
10import time
11import traceback
12from collections import namedtuple
13from .consts import KSFT_MAIN_NAME
14from . import utils
15
# Result of the test case currently run by ksft_run(); stays None until the
# first case starts, which ksft_run() also uses to reject repeated calls.
KSFT_RESULT = None
# AND of every result reported through ktap_result(); decides ksft_exit()'s
# exit status.
KSFT_RESULT_ALL = True
# Whether tests decorated with @ksft_disruptive may run (see ksft_setup()).
KSFT_DISRUPTIVE = True
19
20
class KsftFailEx(Exception):
    """
    Raise from a test case to fail it.

    ksft_run() handles it like any other unexpected exception: the
    traceback is logged and the case is reported as "not ok".
    """
23
24
class KsftSkipEx(Exception):
    """
    Raise from a test case to skip it.

    ksft_run() reports the case with a "SKIP <message>" comment.
    """
27
28
class KsftXfailEx(Exception):
    """
    Raise from a test case to report an expected failure.

    ksft_run() reports the case with an "XFAIL <message>" comment.
    """
31
32
class KsftTerminate(KeyboardInterrupt):
    """
    Raised by the SIGTERM handler (_ksft_intr).

    Because it is a KeyboardInterrupt, ksft_run() stops running further
    test cases after the interrupted one has been cleaned up.
    """
35
36
class _KsftArgs:
    """
    Command line options of a ksft test script, parsed from sys.argv.

    Attributes:
      list_tests: True if -l was given.
      filters: list of (include, pattern) tuples in command line order;
               include is True for -t and False for -T.
    """

    def __init__(self):
        self.list_tests = False
        self.filters = []

        try:
            # Non-option (positional) arguments are ignored.
            opts, _ = getopt.getopt(sys.argv[1:], 'hlt:T:')
        except getopt.GetoptError as e:
            print(e, file=sys.stderr)
            sys.exit(1)

        for opt, val in opts:
            if opt == '-h':
                # -t/-T accept glob patterns and may be repeated (see
                # _ksft_name_matches() / _ksft_test_enabled()); say so.
                print(f"Usage: {sys.argv[0]} [-h|-l] [-t|-T name]\n"
                      f"\t-h       print help\n"
                      f"\t-l       list tests (filtered, if filters were specified)\n"
                      f"\t-t name  include test\n"
                      f"\t-T name  exclude test\n"
                      "\n"
                      "name may be a shell-style glob (*, ?, [...]). -t and -T may be\n"
                      "repeated; the first matching filter decides. Tests matching no\n"
                      "filter run only if no -t was given.",
                      file=sys.stderr)
                sys.exit(0)
            elif opt == '-l':
                self.list_tests = True
            elif opt == '-t':
                self.filters.append((True, val))
            elif opt == '-T':
                self.filters.append((False, val))
63
64
@functools.lru_cache()
def _ksft_supports_color():
    """
    Return True if ANSI colors should be used for stdout.

    Colors are off when NO_COLOR is set (to any value, even empty), when
    stdout has no isatty() or is not a terminal, or when TERM is "dumb".
    The result is cached by lru_cache.
    """
    if os.environ.get("NO_COLOR") is not None:
        return False
    on_tty = hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
    return bool(on_tty) and os.environ.get("TERM") != "dumb"
74
75
def ksft_pr(*objs, **kwargs):
    """
    Log to stdout without breaking the TAP output.

    Works like print(), but every output line starts with "#" (a TAP
    comment) and the output is always flushed.

    Extra keyword argument (on top of what print() accepts):
      line_pfx - string inserted after the "#" on every line
    """
    joiner = kwargs.pop("sep", " ")
    extra = kwargs.pop("line_pfx", "")
    head = "# " + extra if extra else "#"
    kwargs["flush"] = True

    lines = joiner.join(map(str, objs)).split('\n')
    # print() puts a space between head and the text on the first line;
    # every following line gets the same "head " lead-in explicitly.
    print(head, ("\n" + head + " ").join(lines), **kwargs)
94
95
def _fail(*args):
    """
    Record a failed check for the current test case and report it.

    Sets KSFT_RESULT to False. Logs, with a "Check|" prefix, every stack
    frame below ksft_run() down to the caller of the check helper,
    outermost first. Then logs *args*.
    """
    global KSFT_RESULT
    KSFT_RESULT = False

    stack = inspect.stack()
    try:
        started = False
        # stack[0] is this function and stack[1] the ksft_*() helper that
        # detected the failure; walk the rest from the outermost frame in.
        for frame in reversed(stack[2:]):
            # Start printing from the test case function
            if not started:
                if frame.function == 'ksft_run':
                    started = True
                continue

            ksft_pr("Check| At " + frame.filename + ", line " + str(frame.lineno) +
                    ", in " + frame.function + ":")
            # code_context is None when the source is unavailable; don't
            # turn a reported check failure into a TypeError.
            if frame.code_context:
                ksft_pr("Check|     " + frame.code_context[0].strip())
    finally:
        # The FrameInfo list references this function's own frame, forming
        # a reference cycle that would keep every frame on the stack (and
        # the test's locals) alive until the cyclic GC runs.
        del stack
    ksft_pr(*args)
113
114
def ksft_eq(a, b, comment=""):
    """
    Check that a == b.

    On failure the current test case is marked failed via _fail(); the
    test keeps running.
    """
    if a != b:
        _fail("Check failed", a, "!=", b, comment)
119
120
def ksft_ne(a, b, comment=""):
    """
    Check that a != b (fails when a == b).

    On failure the current test case is marked failed via _fail(); the
    test keeps running.
    """
    if a == b:
        _fail("Check failed", a, "==", b, comment)
125
126
def ksft_true(a, comment=""):
    """Check that *a* is truthy; otherwise record a failure via _fail()."""
    if a:
        return
    _fail("Check failed", a, "does not eval to True", comment)
130
131
def ksft_not_none(a, comment=""):
    """Check that *a* is not None; otherwise record a failure via _fail()."""
    if a is not None:
        return
    _fail("Check failed", a, "is None", comment)
135
136
def ksft_in(a, b, comment=""):
    """Check that *a* is contained in *b*; otherwise record a failure."""
    if a in b:
        return
    _fail("Check failed", a, "not in", b, comment)
140
141
def ksft_not_in(a, b, comment=""):
    """Check that *a* is not contained in *b*; otherwise record a failure."""
    if a not in b:
        return
    _fail("Check failed", a, "in", b, comment)
145
146
def ksft_is(a, b, comment=""):
    """Check that *a* and *b* are the same object; otherwise record a failure."""
    if a is b:
        return
    _fail("Check failed", a, "is not", b, comment)
150
151
def ksft_ge(a, b, comment=""):
    """Check that a >= b (fails when a < b); failures go to _fail()."""
    if not a < b:
        return
    _fail("Check failed", a, "<", b, comment)
155
156
def ksft_gt(a, b, comment=""):
    """Check that a > b (fails when a <= b); failures go to _fail()."""
    if not a <= b:
        return
    _fail("Check failed", a, "<=", b, comment)
160
161
def ksft_lt(a, b, comment=""):
    """Check that a < b (fails when a >= b); failures go to _fail()."""
    if not a >= b:
        return
    _fail("Check failed", a, ">=", b, comment)
165
166
class ksft_raises:
    """
    Context manager checking that its body raises *expected_type*.

    The exception type must compare equal to *expected_type*; a subclass
    does not count. On a match the exception is suppressed and stored in
    ``self.exception``. If nothing is raised, or a different exception
    is raised, a check failure is recorded via _fail(). A different
    exception is not suppressed, so it propagates.

    Example:
        with ksft_raises(ValueError) as ctx:
            int("x")
    """

    def __init__(self, expected_type):
        self.expected_type = expected_type
        self.exception = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        wanted = self.expected_type.__name__
        if exc_type is None:
            _fail(f"Expected exception {wanted}, none raised")
        elif self.expected_type != exc_type:
            _fail(f"Expected exception {wanted}, raised {exc_type.__name__}")
        self.exception = exc_val
        # A true return value tells Python to swallow the exception; only
        # do that for the expected type.
        return self.expected_type == exc_type
183
184
def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
    """
    Poll *cond* until it returns true or *deadline* seconds pass.

    *cond* is evaluated before every timeout check, with *sleep* seconds
    between attempts. If the deadline passes first, a check failure is
    recorded via _fail(). Returns None either way.
    """
    give_up_at = time.monotonic() + deadline
    while not cond():
        if time.monotonic() > give_up_at:
            _fail("Waiting for condition timed out", comment)
            return
        time.sleep(sleep)
194
195
def ktap_result(ok, cnt=1, case_name="", comment=""):
    """
    Print one KTAP result line and fold *ok* into KSFT_RESULT_ALL.

    Format: "[not ]ok <cnt> <KSFT_MAIN_NAME>[.<case_name>][ # <comment>]".
    When _ksft_supports_color() allows it, the line is colored:
      - yellow when *comment* starts with SKIP or XFAIL;
      - otherwise green for ok and red for not ok.
    """
    global KSFT_RESULT_ALL
    KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok

    status = "ok" if ok else "not ok"
    line = " ".join((status, str(cnt), KSFT_MAIN_NAME))
    if case_name:
        line += "." + case_name
    if comment:
        line += " # " + comment

    if _ksft_supports_color():
        if comment.startswith(("SKIP", "XFAIL")):
            color = "\033[33m"                         # yellow
        else:
            color = "\033[32m" if ok else "\033[31m"   # green / red
        line = color + line + "\033[0m"
    print(line, flush=True)
219
220
def _ksft_defer_arm(state):
    """
    Allow (state=True) or disallow (state=False) the use of defer().

    Stores *state* in utils.GLOBAL_DEFER_ARMED. ksft_run() arms it only
    while a test case function is executing.
    """
    utils.GLOBAL_DEFER_ARMED = state
224
225
def ksft_flush_defer():
    """
    Run every queued defer() cleanup.

    Entries are popped from the end of utils.GLOBAL_DEFER_QUEUE and run via
    exec_only(). Presumably this runs the newest defer first; confirm
    against utils.defer. If a cleanup raises an Exception, it is logged,
    the current test case is marked failed (KSFT_RESULT = False), and the
    remaining cleanups still run. Other BaseExceptions, such as
    KeyboardInterrupt, propagate to the caller.
    """
    global KSFT_RESULT

    total = len(utils.GLOBAL_DEFER_QUEUE)
    idx = 0
    while utils.GLOBAL_DEFER_QUEUE:
        idx += 1
        entry = utils.GLOBAL_DEFER_QUEUE.pop()
        try:
            entry.exec_only()
        except Exception:
            ksft_pr(f"Exception while handling defer / cleanup (callback {idx} of {total})!")
            ksft_pr(traceback.format_exc(), line_pfx="Defer Exception|")
            KSFT_RESULT = False
240
241
# Record produced by @ksft_variants: the test's name, the undecorated
# function, and the variants it is run with. Being a namedtuple it is
# not callable; ksft_run() expands it into one case per variant.
KsftCaseFunction = namedtuple("KsftCaseFunction",
                              "name original_func variants")
244
245
def ksft_disruptive(func):
    """
    Decorator marking a test as disruptive (e.g. one that can bring a
    network interface down).

    KSFT_DISRUPTIVE is read each time the test is called, not when it is
    decorated. If it is false, the test raises KsftSkipEx and is reported
    as skipped. Set DISRUPTIVE=False (or no / 0) in the environment to
    skip such tests; see ksft_setup().
    """

    @functools.wraps(func)
    def _guarded(*args, **kwargs):
        if KSFT_DISRUPTIVE:
            return func(*args, **kwargs)
        raise KsftSkipEx("marked as disruptive")

    return _guarded
259
260
class KsftNamedVariant:
    """
    A named set of extra test arguments for @ksft_variants.

    Attributes:
      params: tuple of positional arguments appended to the test's args.
      name: *name* if it is truthy, otherwise the params joined with "_"
            (e.g. (1, "a") -> "1_a").
    """

    def __init__(self, name, *params):
        self.params = params
        if name:
            self.name = name
        else:
            self.name = "_".join(map(str, params))
267
268
def ksft_variants(params):
    """
    Decorator defining the sets of inputs for a test.

    Each entry of *params* becomes one sub-case, and its parameters are
    included in the sub-case name. An entry can be:
      - a single object;
      - a tuple of arguments;
      - a KsftNamedVariant.
    *params* may be a list or a generator. It is iterated when ksft_run()
    expands the cases, so a generator is consumed at that point.

    The decorated name is bound to a KsftCaseFunction record rather than
    a function, so it is not callable. Apply @ksft_variants outermost,
    e.g. above @ksft_disruptive.

    Example:

    @ksft_variants([
        (1, "a"),
        (2, "b"),
        KsftNamedVariant("three", 3, "c"),
    ])
    def my_case(cfg, a, b):
        pass # ...

    ksft_run(cases=[my_case], args=(cfg, ))

    Will generate cases:
        my_case.1_a
        my_case.2_b
        my_case.three
    """

    def _decorate(func):
        return KsftCaseFunction(func.__name__, func, params)

    return _decorate
295
296
def ksft_setup(env):
    """
    Setup test framework global state from the environment.

    Only DISRUPTIVE is consumed. When present, it sets KSFT_DISRUPTIVE,
    which controls whether @ksft_disruptive tests run. Accepted values
    (case-insensitive):
      - "yes" / "true";
      - "no" / "false";
      - an integer, where non-zero means true.

    Args:
      env: environment-like mapping (supports ``in`` and ``.get()``).

    Returns:
      *env*, unchanged.

    Raises:
      Exception: if DISRUPTIVE is present but cannot be parsed; the parse
        error is attached as ``__cause__``.
    """
    global KSFT_DISRUPTIVE

    def get_bool(env, name):
        value = env.get(name, "").lower()
        if value in ["yes", "true"]:
            return True
        if value in ["no", "false"]:
            return False
        try:
            return bool(int(value))
        except ValueError as e:
            # int() on a str only fails with ValueError; keep it as the
            # cause instead of hiding it behind the generic message.
            raise Exception(f"failed to parse {name}") from e

    if "DISRUPTIVE" in env:
        KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE")

    return env
318
319
def _ksft_intr(signum, frame):
    """
    SIGTERM handler installed by ksft_run().

    The first SIGTERM raises KsftTerminate (a KeyboardInterrupt), which
    aborts the current test while still letting its cleanup run. The
    kselftest runner.sh sends two SIGTERMs in a row on a timeout. Every
    later SIGTERM is therefore only logged: raising again would interrupt
    the cleanup.
    """
    global term_cnt
    term_cnt += 1
    if term_cnt != 1:
        ksft_pr(f"Ignoring SIGTERM (cnt: {term_cnt}), already exiting...")
        return
    raise KsftTerminate()
329
330
def _ksft_name_matches(name, pattern):
    """
    Match test *name* against a -t/-T *pattern*.

    A pattern containing any of the glob characters *, ? or [ is matched
    with fnmatch.fnmatchcase(), which is case-sensitive. Any other
    pattern must equal *name* exactly.
    """
    if not any(ch in pattern for ch in "*?["):
        return name == pattern
    return fnmatch.fnmatchcase(name, pattern)
335
336
def _ksft_test_enabled(name, filters):
    """
    Decide whether test *name* passes the command line filters.

    *filters* holds (include, pattern) pairs as built by _KsftArgs. The
    first pattern matching *name* decides, including or excluding it.
    If no pattern matches, the test is enabled only when there is no
    include (-t) filter.
    """
    any_include = False
    for include, pattern in filters:
        if _ksft_name_matches(name, pattern):
            return include
        any_include |= include
    return not any_include
344
345
def _ksft_generate_test_cases(cases, globs, case_pfx, args, cli_args):
    """Generate a filtered list of (func, args, name) tuples.

    Cases are taken from *cases*, then from every entry of *globs* whose
    name starts with one of the prefixes in *case_pfx* (a string or an
    iterable of strings). A KsftCaseFunction (from @ksft_variants)
    expands to one case per variant, named "<func>.<variant>", with the
    variant's parameters appended to *args*. The -t/-T filters from
    *cli_args* are then applied.

    If -l is given, prints matching test names and exits.
    """

    # Work on a copy so the caller's list is never extended with the
    # cases found through globs.
    cases = list(cases) if cases else []
    test_cases = []

    # If using the globs method find all relevant functions
    if globs and case_pfx:
        # A bare string would otherwise be iterated character by character.
        if isinstance(case_pfx, str):
            case_pfx = (case_pfx, )
        prefixes = tuple(case_pfx)
        for key, value in globs.items():
            # @ksft_variants cases are KsftCaseFunction namedtuples, which
            # are not callable but must still be picked up.
            if not (callable(value) or isinstance(value, KsftCaseFunction)):
                continue
            if key.startswith(prefixes):
                cases.append(value)

    for func in cases:
        if isinstance(func, KsftCaseFunction):
            # Parametrized test - create case for each param
            for param in func.variants:
                if not isinstance(param, KsftNamedVariant):
                    if not isinstance(param, tuple):
                        param = (param, )
                    param = KsftNamedVariant(None, *param)

                test_cases.append((func.original_func,
                                   (*args, *param.params),
                                   func.name + "." + param.name))
        else:
            test_cases.append((func, args, func.__name__))

    if cli_args.filters:
        test_cases = [tc for tc in test_cases
                      if _ksft_test_enabled(tc[2], cli_args.filters)]

    if cli_args.list_tests:
        for _, _, name in test_cases:
            print(name)
        sys.exit(0)

    return test_cases
390
391
def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
    """
    Run test cases and print the results in KTAP format.

    Cases are collected by _ksft_generate_test_cases():
      - from *cases*;
      - from the functions in *globs* whose names start with a prefix in
        *case_pfx*.
    Each case is called with *args*, plus per-variant parameters for
    @ksft_variants cases. The -l/-t/-T command line options list or
    filter the cases.

    Each case runs with defer() armed, and its queued cleanups are flushed
    after it finishes. A KeyboardInterrupt stops the run once the current
    case's cleanup has finished; this includes a SIGTERM, which
    _ksft_intr converts to KsftTerminate. The previous SIGTERM handler is
    restored afterwards, and a "# Totals" line is printed at the end.

    Raises:
      RuntimeError: if a previous ksft_run() call already ran test cases.
    """
    global KSFT_RESULT
    global term_cnt

    # Reject re-entry before touching process-wide state (the SIGTERM
    # handler) so a misuse does not leave our handler installed.
    if KSFT_RESULT is not None:
        raise RuntimeError("ksft_run() can't be called multiple times.")

    cli_args = _KsftArgs()
    test_cases = _ksft_generate_test_cases(cases, globs, case_pfx, args,
                                           cli_args)

    totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}

    term_cnt = 0
    prev_sigterm = signal.signal(signal.SIGTERM, _ksft_intr)
    try:
        print("TAP version 13", flush=True)
        print("1.." + str(len(test_cases)), flush=True)

        cnt = 0
        stop = False
        for func, case_args, name in test_cases:
            KSFT_RESULT = True
            cnt += 1
            comment = ""
            cnt_key = ""

            _ksft_defer_arm(True)
            try:
                func(*case_args)
            except KsftSkipEx as e:
                comment = "SKIP " + str(e)
                cnt_key = 'skip'
            except KsftXfailEx as e:
                comment = "XFAIL " + str(e)
                cnt_key = 'xfail'
            except BaseException as e:
                stop |= isinstance(e, KeyboardInterrupt)
                ksft_pr(traceback.format_exc(), line_pfx="Exception|")
                if stop:
                    ksft_pr(f"Stopping tests due to {type(e).__name__}.")
                KSFT_RESULT = False
                cnt_key = 'fail'
            _ksft_defer_arm(False)

            try:
                ksft_flush_defer()
            except BaseException as e:
                ksft_pr(traceback.format_exc(), line_pfx="Exception|")
                if isinstance(e, KeyboardInterrupt):
                    ksft_pr()
                    ksft_pr("WARN: defer() interrupted, cleanup may be incomplete.")
                    ksft_pr("      Attempting to finish cleanup before exiting.")
                    ksft_pr("      Interrupt again to exit immediately.")
                    ksft_pr()
                    stop = True
                # Flush was interrupted, try to finish the job best we can
                ksft_flush_defer()

            if not cnt_key:
                cnt_key = 'pass' if KSFT_RESULT else 'fail'

            ktap_result(KSFT_RESULT, cnt, name, comment=comment)
            totals[cnt_key] += 1

            if stop:
                break
    finally:
        # Restore the previous handler even when the run is aborted, e.g.
        # by another interrupt while retrying the deferred cleanup.
        signal.signal(signal.SIGTERM, prev_sigterm)

    print(
        f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0",
        flush=True
    )
464
465
def ksft_exit():
    """
    Exit the process with the overall test status.

    Exits with 0 if every ktap_result() call reported ok
    (KSFT_RESULT_ALL), otherwise with 1.
    """
    sys.exit(int(not KSFT_RESULT_ALL))
469