#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0

"""
tdc.py - Linux tc (Traffic Control) unit test driver

Copyright (C) 2017 Lucas Bates <lucasb@mojatatu.com>
"""

import re
import os
import sys
import argparse
import importlib
import json
import subprocess
import time
import traceback
import random
from multiprocessing import Pool
from collections import OrderedDict
from string import Template

from tdc_config import *
from tdc_helper import *

import TdcPlugin
from TdcResults import *

class PluginDependencyException(Exception):
    def __init__(self, missing_pg):
        self.missing_pg = missing_pg

class PluginMgrTestFail(Exception):
    def __init__(self, stage, output, message):
        self.stage = stage
        self.output = output
        self.message = message

class PluginMgr:
    def __init__(self, argparser):
        super().__init__()
        self.plugins = set()
        self.plugin_instances = []
        self.failed_plugins = {}
        self.argparser = argparser

        plugindir = os.getenv('TDC_PLUGIN_DIR', './plugins')
        for dirpath, dirnames, filenames in os.walk(plugindir):
            for fn in filenames:
                if (fn.endswith('.py') and
                    not fn == '__init__.py' and
                    not fn.startswith('#') and
                    not fn.startswith('.#')):
                    mn = fn[0:-3]
                    foo = importlib.import_module('plugins.' + mn)
                    self.plugins.add(mn)
                    # plugin_instances is a list of (name, instance) tuples
                    self.plugin_instances.append((mn, foo.SubPlugin()))

    def load_plugin(self, pgdir, pgname):
        pgname = pgname[0:-3]
        self.plugins.add(pgname)

        foo = importlib.import_module('{}.{}'.format(pgdir, pgname))

        # nsPlugin must always be the first one
        if pgname == "nsPlugin":
            self.plugin_instances.insert(0, (pgname, foo.SubPlugin()))
            self.plugin_instances[0][1].check_args(self.args, None)
        else:
            self.plugin_instances.append((pgname, foo.SubPlugin()))
            self.plugin_instances[-1][1].check_args(self.args, None)

    def get_required_plugins(self, testlist):
        '''
        Collect the plugins required by the test cases and return the
        set of unique plugin names. As a side effect, replace each test
        case's 'plugins' mapping with its 'requires' value (or an empty
        list).
        '''
        reqs = set()
        for t in testlist:
            try:
                if 'requires' in t['plugins']:
                    if isinstance(t['plugins']['requires'], list):
                        reqs.update(set(t['plugins']['requires']))
                    else:
                        reqs.add(t['plugins']['requires'])
                    t['plugins'] = t['plugins']['requires']
                else:
                    t['plugins'] = []
            except KeyError:
                t['plugins'] = []
                continue

        return reqs

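    # Illustrative 'plugins' stanza in a test case, as parsed above
    # (plugin names are examples from plugin-lib; 'requires' may be a
    # single name or a list):
    #   "plugins": { "requires": "nsPlugin" }
    #   "plugins": { "requires": ["nsPlugin", "scapyPlugin"] }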
    def load_required_plugins(self, reqs, parser, args, remaining):
        '''
        Load any required plugin that is not already loaded, then give
        the newly loaded plugins a chance to add their own arguments.
        '''
        pgd = ['plugin-lib', 'plugin-lib-custom']
        pnf = []

        for r in reqs:
            if r not in self.plugins:
                fname = '{}.py'.format(r)
                source_path = []
                for d in pgd:
                    pgpath = '{}/{}'.format(d, fname)
                    if os.path.isfile(pgpath):
                        source_path.append(pgpath)
                if len(source_path) == 0:
                    print('ERROR: unable to find required plugin {}'.format(r))
                    pnf.append(fname)
                    continue
                elif len(source_path) > 1:
                    print('WARNING: multiple copies of plugin {} found, using version found'.format(r))
                    print('at {}'.format(source_path[0]))
                pgdir = source_path[0]
                pgdir = pgdir.split('/')[0]
                self.load_plugin(pgdir, fname)
        if len(pnf) > 0:
            raise PluginDependencyException(pnf)

        parser = self.call_add_args(parser)
        (args, remaining) = parser.parse_known_args(args=remaining, namespace=args)
        return args

    def call_pre_suite(self, testcount, testidlist):
        for (_, pgn_inst) in self.plugin_instances:
            pgn_inst.pre_suite(testcount, testidlist)

    def call_post_suite(self, index):
        for (_, pgn_inst) in reversed(self.plugin_instances):
            pgn_inst.post_suite(index)

    def call_pre_case(self, caseinfo, *, test_skip=False):
        for (pgn, pgn_inst) in self.plugin_instances:
            if pgn not in caseinfo['plugins']:
                continue
            try:
                pgn_inst.pre_case(caseinfo, test_skip)
            except Exception as ee:
                print('exception {} in call to pre_case for {} plugin'.
                      format(ee, pgn_inst.__class__))
                print('testid is {}'.format(caseinfo['id']))
                raise

    def call_post_case(self, caseinfo):
        for (pgn, pgn_inst) in reversed(self.plugin_instances):
            if pgn not in caseinfo['plugins']:
                continue
            pgn_inst.post_case()

    def call_pre_execute(self, caseinfo):
        for (pgn, pgn_inst) in self.plugin_instances:
            if pgn not in caseinfo['plugins']:
                continue
            pgn_inst.pre_execute()

    def call_post_execute(self, caseinfo):
        for (pgn, pgn_inst) in reversed(self.plugin_instances):
            if pgn not in caseinfo['plugins']:
                continue
            pgn_inst.post_execute()

    def call_add_args(self, parser):
        for (pgn, pgn_inst) in self.plugin_instances:
            parser = pgn_inst.add_args(parser)
        return parser

    def call_check_args(self, args, remaining):
        for (pgn, pgn_inst) in self.plugin_instances:
            pgn_inst.check_args(args, remaining)

    def call_adjust_command(self, caseinfo, stage, command):
        for (pgn, pgn_inst) in self.plugin_instances:
            if pgn not in caseinfo['plugins']:
                continue
            command = pgn_inst.adjust_command(stage, command)
        return command

    def set_args(self, args):
        self.args = args

    @staticmethod
    def _make_argparser(args):
        # a staticmethod has no self; build and return the parser instead
        return argparse.ArgumentParser(
            description='Linux TC unit tests')

def replace_keywords(cmd):
    """
    For a given executable command, substitute any $-prefixed
    keywords with their values from the NAMES dictionary.
    """
    tcmd = Template(cmd)
    subcmd = tcmd.safe_substitute(NAMES)
    return subcmd

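# Illustrative example of keyword substitution, assuming NAMES (from
# tdc_config.py) maps 'TC' to '/sbin/tc' and 'DEV1' to a veth name:
#   replace_keywords('$TC qdisc add dev $DEV1 ingress')
#   -> '/sbin/tc qdisc add dev <DEV1 value> ingress'
# safe_substitute() leaves unknown $placeholders untouched rather than
# raising KeyError.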

def exec_cmd(caseinfo, args, pm, stage, command):
    """
    Perform any required modifications on an executable command, then run
    it in a subprocess and return the results.
    """
    if len(command.strip()) == 0:
        return None, None
    if '$' in command:
        command = replace_keywords(command)

    command = pm.call_adjust_command(caseinfo, stage, command)
    if args.verbose > 0:
        print('command "{}"'.format(command))

    proc = subprocess.Popen(command,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=ENVIR)

    try:
        (rawout, serr) = proc.communicate(timeout=NAMES['TIMEOUT'])
        if proc.returncode != 0 and len(serr) > 0:
            foutput = serr.decode("utf-8", errors="ignore")
        else:
            foutput = rawout.decode("utf-8", errors="ignore")
    except subprocess.TimeoutExpired:
        # don't leave the timed-out child running
        proc.kill()
        foutput = "Command \"{}\" timed out\n".format(command)
        proc.returncode = 255

    proc.stdout.close()
    proc.stderr.close()
    return proc, foutput

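# Note on exec_cmd()'s return contract: an empty command yields
# (None, None); otherwise the Popen object and a single decoded string
# are returned: stderr when the command failed and produced error
# output, stdout otherwise. Callers below rely on this when deciding
# whether a stage succeeded.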

def prepare_env(caseinfo, args, pm, stage, prefix, cmdlist, output=None):
    """
    Execute the setup/teardown commands for a test case.
    Optionally terminate test execution if the command fails.
    """
    if args.verbose > 0:
        print('{}'.format(prefix))
    for cmdinfo in cmdlist:
        if isinstance(cmdinfo, list):
            exit_codes = cmdinfo[1:]
            cmd = cmdinfo[0]
        else:
            exit_codes = [0]
            cmd = cmdinfo

        if not cmd:
            continue

        (proc, foutput) = exec_cmd(caseinfo, args, pm, stage, cmd)

        if proc and (proc.returncode not in exit_codes):
            print('', file=sys.stderr)
            print("{} *** Could not execute: \"{}\"".format(prefix, cmd),
                  file=sys.stderr)
            print("\n{} *** Error message: \"{}\"".format(prefix, foutput),
                  file=sys.stderr)
            print("returncode {}; expected {}".format(proc.returncode,
                                                      exit_codes))
            print("\n{} *** Aborting test run.".format(prefix), file=sys.stderr)
            print("\n\n{} *** stdout ***".format(prefix), file=sys.stderr)
            print("\n\n{} *** stderr ***".format(prefix), file=sys.stderr)
            raise PluginMgrTestFail(
                stage, output,
                '"{}" did not complete successfully'.format(prefix))

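# Illustrative setup/teardown entry formats accepted above (hypothetical
# commands; the list form appends the extra allowed exit codes):
#   "setup": [
#       "$TC qdisc add dev $DEV1 ingress",          # must exit 0
#       ["$TC qdisc del dev $DEV1 ingress", 0, 2]   # 0 or 2 is acceptable
#   ]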
def verify_by_json(procout, res, tidx, args, pm):
    try:
        outputJSON = json.loads(procout)
    except json.JSONDecodeError:
        res.set_result(ResultState.fail)
        res.set_failmsg('Cannot decode verify command\'s output. Is it JSON?')
        return res

    matchJSON = json.loads(json.dumps(tidx['matchJSON']))

    if type(outputJSON) != type(matchJSON):
        failmsg = 'Original output and matchJSON value are not the same type: output: {} != matchJSON: {} '
        failmsg = failmsg.format(type(outputJSON).__name__, type(matchJSON).__name__)
        res.set_result(ResultState.fail)
        res.set_failmsg(failmsg)
        return res

    if len(matchJSON) > len(outputJSON):
        failmsg = "Your matchJSON value is an array, and it contains more elements than the command under test\'s output:\ncommand output (length: {}):\n{}\nmatchJSON value (length: {}):\n{}"
        failmsg = failmsg.format(len(outputJSON), outputJSON, len(matchJSON), matchJSON)
        res.set_result(ResultState.fail)
        res.set_failmsg(failmsg)
        return res
    res = find_in_json(res, outputJSON, matchJSON, 0)

    return res

def find_in_json(res, outputJSONVal, matchJSONVal, matchJSONKey=None):
    if res.get_result() == ResultState.fail:
        return res

    if type(matchJSONVal) == list:
        res = find_in_json_list(res, outputJSONVal, matchJSONVal, matchJSONKey)

    elif type(matchJSONVal) == dict:
        res = find_in_json_dict(res, outputJSONVal, matchJSONVal)
    else:
        res = find_in_json_other(res, outputJSONVal, matchJSONVal, matchJSONKey)

    if res.get_result() != ResultState.fail:
        res.set_result(ResultState.success)
        return res

    return res

def find_in_json_list(res, outputJSONVal, matchJSONVal, matchJSONKey=None):
    if (type(matchJSONVal) != type(outputJSONVal)):
        failmsg = 'Original output and matchJSON value are not the same type: output: {} != matchJSON: {}'
        failmsg = failmsg.format(outputJSONVal, matchJSONVal)
        res.set_result(ResultState.fail)
        res.set_failmsg(failmsg)
        return res

    if len(matchJSONVal) > len(outputJSONVal):
        failmsg = "Your matchJSON value is an array, and it contains more elements than the command under test\'s output:\ncommand output (length: {}):\n{}\nmatchJSON value (length: {}):\n{}"
        failmsg = failmsg.format(len(outputJSONVal), outputJSONVal, len(matchJSONVal), matchJSONVal)
        res.set_result(ResultState.fail)
        res.set_failmsg(failmsg)
        return res

    for matchJSONIdx, matchJSONVal in enumerate(matchJSONVal):
        res = find_in_json(res, outputJSONVal[matchJSONIdx], matchJSONVal,
                           matchJSONKey)
    return res

def find_in_json_dict(res, outputJSONVal, matchJSONVal):
    for matchJSONKey, matchJSONVal in matchJSONVal.items():
        if type(outputJSONVal) == dict:
            if matchJSONKey not in outputJSONVal:
                failmsg = 'Key not found in json output: {}: {}\nMatching against output: {}'
                failmsg = failmsg.format(matchJSONKey, matchJSONVal, outputJSONVal)
                res.set_result(ResultState.fail)
                res.set_failmsg(failmsg)
                return res

        else:
            failmsg = 'Original output and matchJSON value are not the same type: output: {} != matchJSON: {}'
            failmsg = failmsg.format(type(outputJSONVal).__name__, type(matchJSONVal).__name__)
            res.set_result(ResultState.fail)
            res.set_failmsg(failmsg)
            return res

        if type(outputJSONVal) == dict and (type(outputJSONVal[matchJSONKey]) == dict or
                type(outputJSONVal[matchJSONKey]) == list):
            if len(matchJSONVal) > 0:
                res = find_in_json(res, outputJSONVal[matchJSONKey], matchJSONVal, matchJSONKey)
            # handling corner case where matchJSONVal == [] or matchJSONVal == {}
            else:
                res = find_in_json_other(res, outputJSONVal, matchJSONVal, matchJSONKey)
        else:
            res = find_in_json(res, outputJSONVal, matchJSONVal, matchJSONKey)
    return res

def find_in_json_other(res, outputJSONVal, matchJSONVal, matchJSONKey=None):
    if matchJSONKey in outputJSONVal:
        if matchJSONVal != outputJSONVal[matchJSONKey]:
            failmsg = 'Value doesn\'t match: {}: {} != {}\nMatching against output: {}'
            failmsg = failmsg.format(matchJSONKey, matchJSONVal, outputJSONVal[matchJSONKey], outputJSONVal)
            res.set_result(ResultState.fail)
            res.set_failmsg(failmsg)
            return res

    return res

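# Illustrative matchJSON semantics, given a hypothetical verify command
# that prints:  [{"kind": "ingress", "handle": "ffff:"}]
# a matchJSON value of  [{"kind": "ingress"}]  succeeds: objects are
# matched recursively and keys absent from matchJSON are ignored, but a
# matchJSON array may not contain more elements than the output array.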
def run_one_test(pm, args, index, tidx):
    global NAMES
    ns = NAMES['NS']
    dev0 = NAMES['DEV0']
    dev1 = NAMES['DEV1']
    dummy = NAMES['DUMMY']
    result = True
    tresult = ""
    tap = ""
    res = TestResult(tidx['id'], tidx['name'])
    if args.verbose > 0:
        print("\t====================\n=====> ", end="")
    print("Test " + tidx["id"] + ": " + tidx["name"])

    if 'skip' in tidx:
        if tidx['skip'] == 'yes':
            res = TestResult(tidx['id'], tidx['name'])
            res.set_result(ResultState.skip)
            res.set_errormsg('Test case designated as skipped.')
            pm.call_pre_case(tidx, test_skip=True)
            pm.call_post_execute(tidx)
            return res

    if 'dependsOn' in tidx:
        if (args.verbose > 0):
            print('probe command for test skip')
        (p, procout) = exec_cmd(tidx, args, pm, 'execute', tidx['dependsOn'])
        if p:
            if (p.returncode != 0):
                res = TestResult(tidx['id'], tidx['name'])
                res.set_result(ResultState.skip)
                res.set_errormsg('probe command: test skipped.')
                pm.call_pre_case(tidx, test_skip=True)
                pm.call_post_execute(tidx)
                return res

    # populate NAMES with TESTID for this test
    NAMES['TESTID'] = tidx['id']
    NAMES['NS'] = '{}-{}'.format(NAMES['NS'], tidx['random'])
    NAMES['DEV0'] = '{}id{}'.format(NAMES['DEV0'], tidx['id'])
    NAMES['DEV1'] = '{}id{}'.format(NAMES['DEV1'], tidx['id'])
    NAMES['DUMMY'] = '{}id{}'.format(NAMES['DUMMY'], tidx['id'])

    pm.call_pre_case(tidx)
    prepare_env(tidx, args, pm, 'setup', "-----> prepare stage", tidx["setup"])

    if (args.verbose > 0):
        print('-----> execute stage')
    pm.call_pre_execute(tidx)
    (p, procout) = exec_cmd(tidx, args, pm, 'execute', tidx["cmdUnderTest"])
    if p:
        exit_code = p.returncode
    else:
        exit_code = None

    pm.call_post_execute(tidx)

    if (exit_code is None or exit_code != int(tidx["expExitCode"])):
        print("exit: {!r}".format(exit_code))
        print("expected exit: {}".format(int(tidx["expExitCode"])))
        #print("exit: {!r} {}".format(exit_code, int(tidx["expExitCode"])))
        res.set_result(ResultState.fail)
        res.set_failmsg('Command exited with {}, expected {}\n{}'.format(exit_code, tidx["expExitCode"], procout))
        print(procout)
    else:
        if args.verbose > 0:
            print('-----> verify stage')
        (p, procout) = exec_cmd(tidx, args, pm, 'verify', tidx["verifyCmd"])
        if procout:
            if 'matchJSON' in tidx:
                res = verify_by_json(procout, res, tidx, args, pm)
            elif 'matchPattern' in tidx:
                match_pattern = re.compile(
                    str(tidx["matchPattern"]), re.DOTALL | re.MULTILINE)
                match_index = re.findall(match_pattern, procout)
                if len(match_index) != int(tidx["matchCount"]):
                    res.set_result(ResultState.fail)
                    res.set_failmsg('Could not match regex pattern. Verify command output:\n{}'.format(procout))
                else:
                    res.set_result(ResultState.success)
            else:
                res.set_result(ResultState.fail)
                res.set_failmsg('Must specify a match option: matchJSON or matchPattern\n{}'.format(procout))
        elif int(tidx["matchCount"]) != 0:
            res.set_result(ResultState.fail)
            res.set_failmsg('No output generated by verify command.')
        else:
            res.set_result(ResultState.success)

    prepare_env(tidx, args, pm, 'teardown', '-----> teardown stage', tidx['teardown'], procout)
    pm.call_post_case(tidx)

    index += 1

    # remove TESTID from NAMES
    del NAMES['TESTID']

    # Restore names
    NAMES['NS'] = ns
    NAMES['DEV0'] = dev0
    NAMES['DEV1'] = dev1
    NAMES['DUMMY'] = dummy

    return res

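# For reference, the test case fields consumed by the driver (as found
# in the tc-tests JSON files): 'id', 'name', 'category', 'plugins',
# 'skip', 'dependsOn', 'setup', 'cmdUnderTest', 'expExitCode',
# 'verifyCmd', one of 'matchJSON' or 'matchPattern' (the latter paired
# with 'matchCount'), and 'teardown'. 'random' is filled in at runtime
# by set_random() below.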
def prepare_run(pm, args, testlist):
    tcount = len(testlist)
    emergency_exit = False
    emergency_exit_message = ''

    try:
        pm.call_pre_suite(tcount, testlist)
    except Exception as ee:
        ex_type, ex, ex_tb = sys.exc_info()
        print('Exception {} {} (caught in pre_suite).'.
              format(ex_type, ex))
        traceback.print_tb(ex_tb)
        emergency_exit_message = 'EMERGENCY EXIT, call_pre_suite failed with exception {} {}\n'.format(ex_type, ex)
        emergency_exit = True

    if emergency_exit:
        pm.call_post_suite(1)
        return emergency_exit_message

    if args.verbose:
        print('give test rig 2 seconds to stabilize')

    time.sleep(2)

def purge_run(pm, index):
    pm.call_post_suite(index)

def test_runner(pm, args, filtered_tests):
    """
    Driver function for the unit tests.

    Prints information about the tests being run, executes the setup and
    teardown commands and the command under test itself. Also determines
    success/failure based on the information in the test case and generates
    TAP output accordingly.
    """
    testlist = filtered_tests
    tcount = len(testlist)
    index = 1
    tap = ''
    badtest = None
    stage = None

    tsr = TestSuiteReport()

    for tidx in testlist:
        if "flower" in tidx["category"] and args.device is None:
            errmsg = "Tests using the DEV2 variable must define the name of a "
            errmsg += "physical NIC with the -d option when running tdc.\n"
            errmsg += "Test has been skipped."
            if args.verbose > 1:
                print(errmsg)
            res = TestResult(tidx['id'], tidx['name'])
            res.set_result(ResultState.skip)
            res.set_errormsg(errmsg)
            tsr.add_resultdata(res)
            index += 1
            continue
        try:
            badtest = tidx  # in case it goes bad
            res = run_one_test(pm, args, index, tidx)
            tsr.add_resultdata(res)
        except PluginMgrTestFail as pmtf:
            ex_type, ex, ex_tb = sys.exc_info()
            stage = pmtf.stage
            message = pmtf.message
            output = pmtf.output
            res = TestResult(tidx['id'], tidx['name'])
            res.set_result(ResultState.skip)
            res.set_errormsg(pmtf.message)
            res.set_failmsg(pmtf.output)
            tsr.add_resultdata(res)
            index += 1
            print(message)
            print('Exception {} {} (caught in test_runner, running test {} {} {} stage {})'.
                  format(ex_type, ex, index, tidx['id'], tidx['name'], stage))
            print('---------------')
            print('traceback')
            traceback.print_tb(ex_tb)
            print('---------------')
            if stage == 'teardown':
                print('accumulated output for this test:')
                if pmtf.output:
                    print(pmtf.output)
            print('---------------')
            break
        index += 1

    # if we failed in setup or teardown,
    # fill in the remaining tests with ok-skipped
    count = index

    if tcount + 1 != count:
        for tidx in testlist[count - 1:]:
            res = TestResult(tidx['id'], tidx['name'])
            res.set_result(ResultState.skip)
            msg = 'skipped - previous {} failed {} {}'.format(stage,
                index, badtest.get('id', '--Unknown--'))
            res.set_errormsg(msg)
            tsr.add_resultdata(res)
            count += 1

    if args.pause:
        print('Want to pause\nPress enter to continue ...')
        if input():
            print('got something on stdin')

    return (index, tsr)

def mp_bins(alltests):
    serial = []
    parallel = []

    for test in alltests:
        if 'nsPlugin' not in test['plugins']:
            serial.append(test)
        else:
            # We can only create one netdevsim device at a time
            if 'netdevsim/new_device' in str(test['setup']):
                serial.append(test)
            else:
                parallel.append(test)

    return (serial, parallel)

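# Binning summary (a reading of the logic above, not a spec): tests that
# do not run inside a namespace via nsPlugin, plus netdevsim tests, are
# kept serial; everything else is fanned out by test_runner_mp() below
# in batches of 32 across the worker pool.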
def __mp_runner(tests):
    (_, tsr) = test_runner(mp_pm, mp_args, tests)
    return tsr._testsuite

def test_runner_mp(pm, args, alltests):
    prepare_run(pm, args, alltests)

    (serial, parallel) = mp_bins(alltests)

    batches = [parallel[n : n + 32] for n in range(0, len(parallel), 32)]
    batches.insert(0, serial)

    print("Executing {} tests in parallel and {} in serial".format(len(parallel), len(serial)))
    print("Using {} batches".format(len(batches)))

    # We can't pickle these objects, so work around that with globals
    global mp_pm
    mp_pm = pm

    global mp_args
    mp_args = args

    with Pool(args.mp) as p:
        pres = p.map(__mp_runner, batches)

    tsr = TestSuiteReport()
    for trs in pres:
        for res in trs:
            tsr.add_resultdata(res)

    # Passing an index is not useful in MP
    purge_run(pm, None)

    return tsr

def test_runner_serial(pm, args, alltests):
    prepare_run(pm, args, alltests)

    if args.verbose:
        print("Executing {} tests in serial".format(len(alltests)))

    (index, tsr) = test_runner(pm, args, alltests)

    purge_run(pm, index)

    return tsr

def has_blank_ids(idlist):
    """
    Search the list for empty ID fields and return true/false accordingly.
    """
    return not(all(k for k in idlist))


def load_from_file(filename):
    """
    Open the JSON file containing the test cases and return them
    as a list of ordered dictionary objects.
    """
    try:
        with open(filename) as test_data:
            testlist = json.load(test_data, object_pairs_hook=OrderedDict)
    except json.JSONDecodeError as jde:
        print('IGNORING test case file {}\n\tBECAUSE:  {}'.format(filename, jde))
        testlist = list()
    else:
        idlist = get_id_list(testlist)
        if (has_blank_ids(idlist)):
            for k in testlist:
                k['filename'] = filename
    return testlist

def identity(string):
    return string

def args_parse():
    """
    Create the argument parser.
    """
    parser = argparse.ArgumentParser(description='Linux TC unit tests')
    parser.register('type', None, identity)
    return parser


def set_args(parser):
    """
    Set the command line arguments for tdc.
    """
    parser.add_argument(
        '--outfile', type=str,
        help='Path to the file in which results should be saved. ' +
        'Default target is the current directory.')
    parser.add_argument(
        '-p', '--path', type=str,
        help='The full path to the tc executable to use')
    sg = parser.add_argument_group(
        'selection', 'select which test cases: ' +
        'files plus directories; filtered by categories plus testids')
    ag = parser.add_argument_group(
        'action', 'select action to perform on selected test cases')

    sg.add_argument(
        '-D', '--directory', nargs='+', metavar='DIR',
        help='Collect tests from the specified directory(ies) ' +
        '(default [tc-tests])')
    sg.add_argument(
        '-f', '--file', nargs='+', metavar='FILE',
        help='Run tests from the specified file(s)')
    sg.add_argument(
        '-c', '--category', nargs='*', metavar='CATG', default=['+c'],
        help='Run tests only from the specified categories, ' +
        'or list known categories if none are specified.')
    sg.add_argument(
        '-e', '--execute', nargs='+', metavar='ID',
        help='Execute only the test cases with the specified IDs')
    ag.add_argument(
        '-l', '--list', action='store_true',
        help='List all test cases, or only those within the specified category')
    ag.add_argument(
        '-s', '--show', action='store_true', dest='showID',
        help='Display the selected test cases')
    ag.add_argument(
        '-i', '--id', action='store_true', dest='gen_id',
        help='Generate ID numbers for new test cases')
    parser.add_argument(
        '-v', '--verbose', action='count', default=0,
        help='Show the commands that are being run')
    parser.add_argument(
        '--format', default='tap', const='tap', nargs='?',
        choices=['none', 'xunit', 'tap'],
        help='Specify the format for test results. (Default: TAP)')
    parser.add_argument('-d', '--device',
                        help='Execute test cases that use a physical device, ' +
                        'where DEVICE is its name. (If not defined, tests ' +
                        'that require a physical device will be skipped)')
    parser.add_argument(
        '-P', '--pause', action='store_true',
        help='Pause execution just before post-suite stage')
    parser.add_argument(
        '-J', '--multiprocess', type=int, default=1, dest='mp',
        help='Run tests in parallel whenever possible')
    return parser


def check_default_settings(args, remaining, pm):
    """
    Process any arguments overriding the default settings,
    and ensure the settings are correct.
    """
    # Allow for overriding specific settings
    global NAMES

    if args.path is not None:
        NAMES['TC'] = args.path
    if args.device is not None:
        NAMES['DEV2'] = args.device
    if 'TIMEOUT' not in NAMES:
        NAMES['TIMEOUT'] = None
    if not os.path.isfile(NAMES['TC']):
        print("The specified tc path " + NAMES['TC'] + " does not exist.")
        exit(1)

    pm.call_check_args(args, remaining)


def get_id_list(alltests):
    """
    Generate a list of all IDs in the test cases.
    """
    return [x["id"] for x in alltests]

def check_case_id(alltests):
    """
    Check for duplicate test case IDs.
    """
    idl = get_id_list(alltests)
    return [x for x in idl if idl.count(x) > 1]


def does_id_exist(alltests, newid):
    """
    Check if a given ID already exists in the list of test cases.
    """
    idl = get_id_list(alltests)
    return (any(newid == x for x in idl))


def generate_case_ids(alltests):
    """
    If a test case has a blank ID field, generate a random hex ID for it
    and then write the test cases back to disk.
    """
    for c in alltests:
        if (c["id"] == ""):
            while True:
                newid = str('{:04x}'.format(random.randrange(16**4)))
                if (does_id_exist(alltests, newid)):
                    continue
                else:
                    c['id'] = newid
                    break

    ufilename = []
    for c in alltests:
        if ('filename' in c):
            ufilename.append(c['filename'])
    ufilename = get_unique_item(ufilename)
    for f in ufilename:
        testlist = []
        for t in alltests:
            if 'filename' in t:
                if t['filename'] == f:
                    del t['filename']
                    testlist.append(t)
        with open(f, "w") as outfile:
            json.dump(testlist, outfile, indent=4)
            outfile.write("\n")
    return alltests

def filter_tests_by_id(args, testlist):
    '''
    Remove tests from testlist that are not in the named id list.
    If id list is empty, return empty list.
    '''
    newlist = list()
    if testlist and args.execute:
        target_ids = args.execute

        if isinstance(target_ids, list) and (len(target_ids) > 0):
            newlist = list(filter(lambda x: x['id'] in target_ids, testlist))
    return newlist

def filter_tests_by_category(args, testlist):
    '''
    Remove tests from testlist that are not in a named category.
    '''
    answer = list()
    if args.category and testlist:
        test_ids = list()
        for catg in set(args.category):
            if catg == '+c':
                continue
            print('considering category {}'.format(catg))
            for tc in testlist:
                if catg in tc['category'] and tc['id'] not in test_ids:
                    answer.append(tc)
                    test_ids.append(tc['id'])

    return answer

def set_random(alltests):
    for tidx in alltests:
        tidx['random'] = random.getrandbits(32)

def get_test_cases(args):
    """
    If a test case file is specified, retrieve tests from that file.
    Otherwise, glob for all json files in subdirectories and load from
    each one.
    Also, if requested, filter by category, and add tests matching
    certain ids.
    """
    import fnmatch

    flist = []
    testdirs = ['tc-tests']

    if args.file:
        # at least one file was specified - remove the default directory
        testdirs = []

        for ff in args.file:
            if not os.path.isfile(ff):
                print("IGNORING file " + ff + "\n\tBECAUSE it does not exist.")
            else:
                flist.append(os.path.abspath(ff))

    if args.directory:
        testdirs = args.directory

    for testdir in testdirs:
        for root, dirnames, filenames in os.walk(testdir):
            for filename in fnmatch.filter(filenames, '*.json'):
                candidate = os.path.abspath(os.path.join(root, filename))
                if candidate not in flist:
                    flist.append(candidate)

    alltestcases = list()
    for casefile in flist:
        alltestcases = alltestcases + (load_from_file(casefile))

    allcatlist = get_test_categories(alltestcases)
    allidlist = get_id_list(alltestcases)

    testcases_by_cats = get_categorized_testlist(alltestcases, allcatlist)
    idtestcases = filter_tests_by_id(args, alltestcases)
    cattestcases = filter_tests_by_category(args, alltestcases)

    cat_ids = [x['id'] for x in cattestcases]
    if args.execute:
        if args.category:
            alltestcases = cattestcases + [x for x in idtestcases if x['id'] not in cat_ids]
        else:
            alltestcases = idtestcases
    else:
        if cat_ids:
            alltestcases = cattestcases
        else:
            # just accept the existing value of alltestcases,
            # which has been filtered by file/directory
            pass

    return allcatlist, allidlist, testcases_by_cats, alltestcases


def set_operation_mode(pm, parser, args, remaining):
    """
    Load the test case data and process remaining arguments to determine
    what the script should do for this run, and call the appropriate
    function.
    """
    ucat, idlist, testcases, alltests = get_test_cases(args)

    if args.gen_id:
        if (has_blank_ids(idlist)):
            alltests = generate_case_ids(alltests)
        else:
            print("No empty ID fields found in test files.")
        exit(0)

    duplicate_ids = check_case_id(alltests)
    if (len(duplicate_ids) > 0):
        print("The following test case IDs are not unique:")
        print(str(set(duplicate_ids)))
        print("Please correct them before continuing.")
        exit(1)

    if args.showID:
        for atest in alltests:
            print_test_case(atest)
        exit(0)

    if isinstance(args.category, list) and (len(args.category) == 0):
        print("Available categories:")
        print_sll(ucat)
        exit(0)

    if args.list:
        list_test_cases(alltests)
        exit(0)

    set_random(alltests)

    exit_code = 0 # KSFT_PASS
    if len(alltests):
        req_plugins = pm.get_required_plugins(alltests)
        try:
            args = pm.load_required_plugins(req_plugins, parser, args, remaining)
        except PluginDependencyException as pde:
            print('The following plugins were not found:')
            print('{}'.format(pde.missing_pg))

        if args.mp > 1:
            catresults = test_runner_mp(pm, args, alltests)
        else:
            catresults = test_runner_serial(pm, args, alltests)

        if catresults.count_failures() != 0:
            exit_code = 1 # KSFT_FAIL
        if args.format == 'none':
            print('Test results output suppression requested\n')
        else:
            print('\nAll test results: \n')
            if args.format == 'xunit':
                suffix = 'xml'
                res = catresults.format_xunit()
            elif args.format == 'tap':
                suffix = 'tap'
                res = catresults.format_tap()
            print(res)
            print('\n\n')
            if not args.outfile:
                fname = 'test-results.{}'.format(suffix)
            else:
                fname = args.outfile
            with open(fname, 'w') as fh:
                fh.write(res)
                if os.getenv('SUDO_UID') is not None:
                    os.chown(fname, uid=int(os.getenv('SUDO_UID')),
                        gid=int(os.getenv('SUDO_GID')))
    else:
        print('No tests found\n')
        exit_code = 4 # KSFT_SKIP
    exit(exit_code)

def main():
    """
    Start of execution; set up argument parser and get the arguments,
    and start operations.
    """
    import resource

    if sys.version_info < (3, 8):
        sys.exit("tdc requires at least python 3.8")

    resource.setrlimit(resource.RLIMIT_NOFILE, (1048576, 1048576))

    parser = args_parse()
    parser = set_args(parser)
    pm = PluginMgr(parser)
    parser = pm.call_add_args(parser)
    (args, remaining) = parser.parse_known_args()
    args.NAMES = NAMES
    pm.set_args(args)
    check_default_settings(args, remaining, pm)
    if args.verbose > 2:
        print('args is {}'.format(args))

    set_operation_mode(pm, parser, args, remaining)

if __name__ == "__main__":
    main()