xref: /linux/tools/lib/python/unittest_helper.py (revision 781171bec0650c00c642564afcb5cce57abda5bf)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: GPL-2.0
3# Copyright(c) 2025-2026: Mauro Carvalho Chehab <mchehab@kernel.org>.
4#
5# pylint: disable=C0103,R0912,R0914,E1101
6
7"""
Provides helper functions and classes to execute Python unit tests.
9
These helper functions provide a nicely colored summary of each
executed test and, when running in verbose mode, show the difference
in diff format when a test fails, like::
13
14    $ tools/unittests/nested_match.py -v
15    ...
16    Traceback (most recent call last):
17    File "/new_devel/docs/tools/unittests/nested_match.py", line 69, in test_count_limit
18        self.assertEqual(replaced, "bar(a); bar(b); foo(c)")
19        ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
20    AssertionError: 'bar(a) foo(b); foo(c)' != 'bar(a); bar(b); foo(c)'
21    - bar(a) foo(b); foo(c)
22    ?       ^^^^
23    + bar(a); bar(b); foo(c)
24    ?       ^^^^^
25    ...
26
It also allows filtering which tests are executed via the ``-k`` parameter,
which takes a regular expression searched for in each test method name.
28
29Typical usage is to do::
30
31    from unittest_helper import run_unittest
32    ...
33
34    if __name__ == "__main__":
35        run_unittest(__file__)
36
In more complex scenarios, where extra arguments are needed, it can be
used as in this example::
39
40    from unittest_helper import TestUnits, run_unittest
41    ...
42    env = {'sudo': ""}
43    ...
44    if __name__ == "__main__":
45        runner = TestUnits()
46        base_parser = runner.parse_args()
47        base_parser.add_argument('--sudo', action='store_true',
48                                help='Enable tests requiring sudo privileges')
49
50        args = base_parser.parse_args()
51
52        # Update module-level flag
53        if args.sudo:
54            env['sudo'] = "1"
55
56        # Run tests with customized arguments
57        runner.run(__file__, parser=base_parser, args=args, env=env)
58"""
59
60import argparse
61import atexit
62import os
63import re
64import unittest
65import sys
66
67from unittest.mock import patch
68
69
70class Summary(unittest.TestResult):
71    """
72    Overrides ``unittest.TestResult`` class to provide a nice colored
73    summary. When in verbose mode, displays actual/expected difference in
74    unified diff format.
75    """
76    def __init__(self, *args, **kwargs):
77        super().__init__(*args, **kwargs)
78
79        #: Dictionary to store organized test results.
80        self.test_results = {}
81
82        #: max length of the test names.
83        self.max_name_length = 0
84
85    def startTest(self, test):
86        super().startTest(test)
87        test_id = test.id()
88        parts = test_id.split(".")
89
90        # Extract module, class, and method names
91        if len(parts) >= 3:
92            module_name = parts[-3]
93        else:
94            module_name = ""
95        if len(parts) >= 2:
96            class_name = parts[-2]
97        else:
98            class_name = ""
99
100        method_name = parts[-1]
101
102        # Build the hierarchical structure
103        if module_name not in self.test_results:
104            self.test_results[module_name] = {}
105
106        if class_name not in self.test_results[module_name]:
107            self.test_results[module_name][class_name] = []
108
109        # Track maximum test name length for alignment
110        display_name = f"{method_name}:"
111
112        self.max_name_length = max(len(display_name), self.max_name_length)
113
114    def _record_test(self, test, status):
115        test_id = test.id()
116        parts = test_id.split(".")
117        if len(parts) >= 3:
118            module_name = parts[-3]
119        else:
120            module_name = ""
121        if len(parts) >= 2:
122            class_name = parts[-2]
123        else:
124            class_name = ""
125        method_name = parts[-1]
126        self.test_results[module_name][class_name].append((method_name, status))
127
128    def addSuccess(self, test):
129        super().addSuccess(test)
130        self._record_test(test, "OK")
131
132    def addFailure(self, test, err):
133        super().addFailure(test, err)
134        self._record_test(test, "FAIL")
135
136    def addError(self, test, err):
137        super().addError(test, err)
138        self._record_test(test, "ERROR")
139
140    def addSkip(self, test, reason):
141        super().addSkip(test, reason)
142        self._record_test(test, f"SKIP ({reason})")
143
144    def printResults(self):
145        """
146        Print results using colors if tty.
147        """
148        # Check for ANSI color support
149        use_color = sys.stdout.isatty()
150        COLORS = {
151            "OK":            "\033[32m",   # Green
152            "FAIL":          "\033[31m",   # Red
153            "SKIP":          "\033[1;33m", # Yellow
154            "PARTIAL":       "\033[33m",   # Orange
155            "EXPECTED_FAIL": "\033[36m",   # Cyan
156            "reset":         "\033[0m",    # Reset to default terminal color
157        }
158        if not use_color:
159            for c in COLORS:
160                COLORS[c] = ""
161
162        # Calculate maximum test name length
163        if not self.test_results:
164            return
165        try:
166            lengths = []
167            for module in self.test_results.values():
168                for tests in module.values():
169                    for test_name, _ in tests:
170                        lengths.append(len(test_name) + 1)  # +1 for colon
171            max_length = max(lengths) + 2  # Additional padding
172        except ValueError:
173            sys.exit("Test list is empty")
174
175        # Print results
176        for module_name, classes in self.test_results.items():
177            print(f"{module_name}:")
178            for class_name, tests in classes.items():
179                print(f"    {class_name}:")
180                for test_name, status in tests:
181                    # Get base status without reason for SKIP
182                    if status.startswith("SKIP"):
183                        status_code = status.split()[0]
184                    else:
185                        status_code = status
186                    color = COLORS.get(status_code, "")
187                    print(
188                        f"        {test_name + ':':<{max_length}}{color}{status}{COLORS['reset']}"
189                    )
190            print()
191
192        # Print summary
193        print(f"\nRan {self.testsRun} tests", end="")
194        if hasattr(self, "timeTaken"):
195            print(f" in {self.timeTaken:.3f}s", end="")
196        print()
197
198        if not self.wasSuccessful():
199            print(f"\n{COLORS['FAIL']}FAILED (", end="")
200            failures = getattr(self, "failures", [])
201            errors = getattr(self, "errors", [])
202            if failures:
203                print(f"failures={len(failures)}", end="")
204            if errors:
205                if failures:
206                    print(", ", end="")
207                print(f"errors={len(errors)}", end="")
208            print(f"){COLORS['reset']}")
209
210
def flatten_suite(suite):
    """
    Return the individual tests contained in ``suite`` as a flat list.

    Nested ``unittest.TestSuite`` objects are expanded depth-first, so the
    returned order matches the order in which the suite iterates its
    tests. Items that are not a ``TestSuite`` are kept as they are.
    """
    flat = []
    # One iterator per suite currently being walked; the innermost is last.
    pending = [iter(suite)]
    while pending:
        for entry in pending[-1]:
            if isinstance(entry, unittest.TestSuite):
                # Descend now; the outer iterator resumes after it is done.
                pending.append(iter(entry))
                break
            flat.append(entry)
        else:
            pending.pop()
    return flat
220
221
class TestUnits:
    """
    Discover, filter and run the unit tests of a test file.

    :meth:`parse_args` provides the command-line options understood by
    :meth:`run`. That method loads the tests, optionally filters them by
    method name, runs them and ends the process with an exit status that
    reflects the outcome.
    """
    def parse_args(self):
        """
        Return an ``argparse.ArgumentParser`` with the runner options.

        Despite its name, nothing is parsed here; callers may add their own
        options to the returned parser before parsing. Options:

        - ``-v``/``--verbose``: counted, starting at 1 (so ``-v`` gives 2);
        - ``-f``/``--failfast``: stop at the first failure or error;
        - ``-k``/``--keyword``: regex searched for in each test method name.
        """
        parser = argparse.ArgumentParser(description="Test runner with regex filtering")
        parser.add_argument("-v", "--verbose", action="count", default=1)
        parser.add_argument("-f", "--failfast", action="store_true")
        parser.add_argument("-k", "--keyword",
                            help="Regex pattern to filter test methods")
        return parser

    def run(self, caller_file=None, pattern=None,
            suite=None, parser=None, args=None, env=None):
        """
        Execute all tests from the unit test file.

        It contains several optional parameters:

        ``caller_file``:
            -  name of the file that contains test.

               typical usage is to place __file__ at the caller test, e.g.::

                    if __name__ == "__main__":
                        TestUnits().run(__file__)

        ``pattern``:
            - optional glob pattern, passed to
              ``unittest.TestLoader().discover()``, to match multiple file
              names in the directory of ``caller_file``. Defaults to the
              basename of ``caller_file``.

        ``suite``:
            - an unittest suite initialized by the caller using
              ``unittest.TestLoader().discover()``. When given, no
              discovery is done (``pattern`` is not used).

        ``parser``:
            - an argparse parser. If not defined, this helper will create
              one. Not used when ``args`` is given.

        ``args``:
            - an ``argparse.Namespace`` data filled by the caller. It must
              provide ``verbose``, ``failfast`` and ``keyword``.

        ``env``:
            - environment variables that will be passed to the test suite.
              Defaults to a copy of ``os.environ``. ``VERBOSE`` is set in
              this dict (modified in place), and its contents are applied
              to ``os.environ`` until the interpreter exits.

        At least ``caller_file`` or ``suite`` must be used, otherwise a
        ``TypeError`` will be raised.

        This method does not return normally: it calls ``sys.exit()`` with
        status 0 when all tests pass and 1 otherwise (also for an invalid
        ``-k`` regex).
        """
        if not args:
            if not parser:
                parser = self.parse_args()
            args = parser.parse_args()

        if not caller_file and not suite:
            raise TypeError("Either caller_file or suite is needed at TestUnits")

        verbose = args.verbose

        if not env:
            env = os.environ.copy()

        # Let the tests know the verbosity level.
        env["VERBOSE"] = f"{verbose}"

        # Expose env to the tests through os.environ; the original
        # environment is restored at interpreter exit.
        patcher = patch.dict(os.environ, env)
        patcher.start()
        atexit.register(patcher.stop)

        # Load ONLY tests from the calling file (or files matching pattern)
        if not suite:
            if not pattern:
                pattern = os.path.basename(caller_file)

            loader = unittest.TestLoader()
            suite = loader.discover(start_dir=os.path.dirname(caller_file),
                                    pattern=pattern)

        tests = flatten_suite(suite)

        # Filter tests by method name if -k specified
        if args.keyword:
            try:
                keyword_re = re.compile(args.keyword)
            except re.error as e:
                sys.stderr.write(f"Invalid regex pattern: {e}\n")
                sys.exit(1)
            suite = unittest.TestSuite(
                test for test in tests
                if keyword_re.search(test.id().split(".")[-1]))
        else:
            suite = unittest.TestSuite(tests)

        # Verbose runs use the stock TextTestResult (per-test output and
        # assertion diffs); otherwise use the compact colored Summary.
        resultclass = None if verbose >= 2 else Summary

        runner = unittest.TextTestRunner(verbosity=verbose,
                                         resultclass=resultclass,
                                         failfast=args.failfast)
        result = runner.run(suite)
        if resultclass:
            result.printResults()

        # True -> exit status 1, False -> exit status 0
        sys.exit(not result.wasSuccessful())
340
341
def run_unittest(fname):
    """
    Run every test found in ``fname`` using the default runner options.

    Convenience wrapper around ``TestUnits().run()`` for test modules that
    need no extra command-line arguments or environment variables. The
    recommended way is to place this at the end of each unittest module::

        if __name__ == "__main__":
            run_unittest(__file__)
    """
    runner = TestUnits()
    runner.run(fname)
354