1#!/usr/bin/env python3 2# SPDX-License-Identifier: GPL-2.0 3# Copyright(c) 2025-2026: Mauro Carvalho Chehab <mchehab@kernel.org>. 4# 5# pylint: disable=C0103,R0912,R0914,E1101 6 7""" 8Provides helper functions and classes execute python unit tests. 9 10Those help functions provide a nice colored output summary of each 11executed test and, when a test fails, it shows the different in diff 12format when running in verbose mode, like:: 13 14 $ tools/unittests/nested_match.py -v 15 ... 16 Traceback (most recent call last): 17 File "/new_devel/docs/tools/unittests/nested_match.py", line 69, in test_count_limit 18 self.assertEqual(replaced, "bar(a); bar(b); foo(c)") 19 ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 20 AssertionError: 'bar(a) foo(b); foo(c)' != 'bar(a); bar(b); foo(c)' 21 - bar(a) foo(b); foo(c) 22 ? ^^^^ 23 + bar(a); bar(b); foo(c) 24 ? ^^^^^ 25 ... 26 27It also allows filtering what tests will be executed via ``-k`` parameter. 28 29Typical usage is to do:: 30 31 from unittest_helper import run_unittest 32 ... 33 34 if __name__ == "__main__": 35 run_unittest(__file__) 36 37If passing arguments is needed, on a more complex scenario, it can be 38used like on this example:: 39 40 from unittest_helper import TestUnits, run_unittest 41 ... 42 env = {'sudo': ""} 43 ... 44 if __name__ == "__main__": 45 runner = TestUnits() 46 base_parser = runner.parse_args() 47 base_parser.add_argument('--sudo', action='store_true', 48 help='Enable tests requiring sudo privileges') 49 50 args = base_parser.parse_args() 51 52 # Update module-level flag 53 if args.sudo: 54 env['sudo'] = "1" 55 56 # Run tests with customized arguments 57 runner.run(__file__, parser=base_parser, args=args, env=env) 58""" 59 60import argparse 61import atexit 62import os 63import re 64import unittest 65import sys 66 67from unittest.mock import patch 68 69 70class Summary(unittest.TestResult): 71 """ 72 Overrides ``unittest.TestResult`` class to provide a nice colored 73 summary. When in verbose mode, displays actual/expected difference in 74 unified diff format. 75 """ 76 def __init__(self, *args, **kwargs): 77 super().__init__(*args, **kwargs) 78 79 #: Dictionary to store organized test results. 80 self.test_results = {} 81 82 #: max length of the test names. 83 self.max_name_length = 0 84 85 def startTest(self, test): 86 super().startTest(test) 87 test_id = test.id() 88 parts = test_id.split(".") 89 90 # Extract module, class, and method names 91 if len(parts) >= 3: 92 module_name = parts[-3] 93 else: 94 module_name = "" 95 if len(parts) >= 2: 96 class_name = parts[-2] 97 else: 98 class_name = "" 99 100 method_name = parts[-1] 101 102 # Build the hierarchical structure 103 if module_name not in self.test_results: 104 self.test_results[module_name] = {} 105 106 if class_name not in self.test_results[module_name]: 107 self.test_results[module_name][class_name] = [] 108 109 # Track maximum test name length for alignment 110 display_name = f"{method_name}:" 111 112 self.max_name_length = max(len(display_name), self.max_name_length) 113 114 def _record_test(self, test, status): 115 test_id = test.id() 116 parts = test_id.split(".") 117 if len(parts) >= 3: 118 module_name = parts[-3] 119 else: 120 module_name = "" 121 if len(parts) >= 2: 122 class_name = parts[-2] 123 else: 124 class_name = "" 125 method_name = parts[-1] 126 self.test_results[module_name][class_name].append((method_name, status)) 127 128 def addSuccess(self, test): 129 super().addSuccess(test) 130 self._record_test(test, "OK") 131 132 def addFailure(self, test, err): 133 super().addFailure(test, err) 134 self._record_test(test, "FAIL") 135 136 def addError(self, test, err): 137 super().addError(test, err) 138 self._record_test(test, "ERROR") 139 140 def addSkip(self, test, reason): 141 super().addSkip(test, reason) 142 self._record_test(test, f"SKIP ({reason})") 143 144 def printResults(self): 145 """ 146 Print results using colors if tty. 147 """ 148 # Check for ANSI color support 149 use_color = sys.stdout.isatty() 150 COLORS = { 151 "OK": "\033[32m", # Green 152 "FAIL": "\033[31m", # Red 153 "SKIP": "\033[1;33m", # Yellow 154 "PARTIAL": "\033[33m", # Orange 155 "EXPECTED_FAIL": "\033[36m", # Cyan 156 "reset": "\033[0m", # Reset to default terminal color 157 } 158 if not use_color: 159 for c in COLORS: 160 COLORS[c] = "" 161 162 # Calculate maximum test name length 163 if not self.test_results: 164 return 165 try: 166 lengths = [] 167 for module in self.test_results.values(): 168 for tests in module.values(): 169 for test_name, _ in tests: 170 lengths.append(len(test_name) + 1) # +1 for colon 171 max_length = max(lengths) + 2 # Additional padding 172 except ValueError: 173 sys.exit("Test list is empty") 174 175 # Print results 176 for module_name, classes in self.test_results.items(): 177 print(f"{module_name}:") 178 for class_name, tests in classes.items(): 179 print(f" {class_name}:") 180 for test_name, status in tests: 181 # Get base status without reason for SKIP 182 if status.startswith("SKIP"): 183 status_code = status.split()[0] 184 else: 185 status_code = status 186 color = COLORS.get(status_code, "") 187 print( 188 f" {test_name + ':':<{max_length}}{color}{status}{COLORS['reset']}" 189 ) 190 print() 191 192 # Print summary 193 print(f"\nRan {self.testsRun} tests", end="") 194 if hasattr(self, "timeTaken"): 195 print(f" in {self.timeTaken:.3f}s", end="") 196 print() 197 198 if not self.wasSuccessful(): 199 print(f"\n{COLORS['FAIL']}FAILED (", end="") 200 failures = getattr(self, "failures", []) 201 errors = getattr(self, "errors", []) 202 if failures: 203 print(f"failures={len(failures)}", end="") 204 if errors: 205 if failures: 206 print(", ", end="") 207 print(f"errors={len(errors)}", end="") 208 print(f"){COLORS['reset']}") 209 210 211def flatten_suite(suite): 212 """Flatten test suite hierarchy.""" 213 tests = [] 214 for item in suite: 215 if isinstance(item, unittest.TestSuite): 216 tests.extend(flatten_suite(item)) 217 else: 218 tests.append(item) 219 return tests 220 221 222class TestUnits: 223 """ 224 Helper class to set verbosity level. 225 226 This class discover test files, import its unittest classes and 227 executes the test on it. 228 """ 229 def parse_args(self): 230 """Returns a parser for command line arguments.""" 231 parser = argparse.ArgumentParser(description="Test runner with regex filtering") 232 parser.add_argument("-v", "--verbose", action="count", default=1) 233 parser.add_argument("-f", "--failfast", action="store_true") 234 parser.add_argument("-k", "--keyword", 235 help="Regex pattern to filter test methods") 236 return parser 237 238 def run(self, caller_file=None, pattern=None, 239 suite=None, parser=None, args=None, env=None): 240 """ 241 Execute all tests from the unity test file. 242 243 It contains several optional parameters: 244 245 ``caller_file``: 246 - name of the file that contains test. 247 248 typical usage is to place __file__ at the caller test, e.g.:: 249 250 if __name__ == "__main__": 251 TestUnits().run(__file__) 252 253 ``pattern``: 254 - optional pattern to match multiple file names. Defaults 255 to basename of ``caller_file``. 256 257 ``suite``: 258 - an unittest suite initialized by the caller using 259 ``unittest.TestLoader().discover()``. 260 261 ``parser``: 262 - an argparse parser. If not defined, this helper will create 263 one. 264 265 ``args``: 266 - an ``argparse.Namespace`` data filled by the caller. 267 268 ``env``: 269 - environment variables that will be passed to the test suite 270 271 At least ``caller_file`` or ``suite`` must be used, otherwise a 272 ``TypeError`` will be raised. 273 """ 274 if not args: 275 if not parser: 276 parser = self.parse_args() 277 args = parser.parse_args() 278 279 if not caller_file and not suite: 280 raise TypeError("Either caller_file or suite is needed at TestUnits") 281 282 verbose = args.verbose 283 284 if not env: 285 env = os.environ.copy() 286 287 env["VERBOSE"] = f"{verbose}" 288 289 patcher = patch.dict(os.environ, env) 290 patcher.start() 291 # ensure it gets stopped after 292 atexit.register(patcher.stop) 293 294 295 if verbose >= 2: 296 unittest.TextTestRunner(verbosity=verbose).run = lambda suite: suite 297 298 # Load ONLY tests from the calling file 299 if not suite: 300 if not pattern: 301 pattern = caller_file 302 303 loader = unittest.TestLoader() 304 suite = loader.discover(start_dir=os.path.dirname(caller_file), 305 pattern=os.path.basename(caller_file)) 306 307 # Flatten the suite for environment injection 308 tests_to_inject = flatten_suite(suite) 309 310 # Filter tests by method name if -k specified 311 if args.keyword: 312 try: 313 pattern = re.compile(args.keyword) 314 filtered_suite = unittest.TestSuite() 315 for test in tests_to_inject: # Use the pre-flattened list 316 method_name = test.id().split(".")[-1] 317 if pattern.search(method_name): 318 filtered_suite.addTest(test) 319 suite = filtered_suite 320 except re.error as e: 321 sys.stderr.write(f"Invalid regex pattern: {e}\n") 322 sys.exit(1) 323 else: 324 # Maintain original suite structure if no keyword filtering 325 suite = unittest.TestSuite(tests_to_inject) 326 327 if verbose >= 2: 328 resultclass = None 329 else: 330 resultclass = Summary 331 332 runner = unittest.TextTestRunner(verbosity=args.verbose, 333 resultclass=resultclass, 334 failfast=args.failfast) 335 result = runner.run(suite) 336 if resultclass: 337 result.printResults() 338 339 sys.exit(not result.wasSuccessful()) 340 341 342def run_unittest(fname): 343 """ 344 Basic usage of TestUnits class. 345 346 Use it when there's no need to pass any extra argument to the tests 347 with. The recommended way is to place this at the end of each 348 unittest module:: 349 350 if __name__ == "__main__": 351 run_unittest(__file__) 352 """ 353 TestUnits().run(fname) 354