xref: /linux/tools/testing/selftests/tc-testing/tdc.py (revision 90e63d5354951d37fa2b3b91e6f17b95d2bf9bee)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: GPL-2.0
3
4"""
5tdc.py - Linux tc (Traffic Control) unit test driver
6
7Copyright (C) 2017 Lucas Bates <lucasb@mojatatu.com>
8"""
9
10import re
11import os
12import sys
13import argparse
14import importlib
15import json
16import subprocess
17import time
18import traceback
19import random
20from multiprocessing import Pool
21from collections import OrderedDict
22from string import Template
23
24from tdc_config import *
25from tdc_helper import *
26
27import TdcPlugin
28from TdcResults import *
29
class PluginDependencyException(Exception):
    """
    Raised when one or more plugins required by the selected test cases
    cannot be found.

    missing_pg -- the plugin file names (e.g. 'fooPlugin.py') that were not
                  found; also passed to Exception so str(exc) is informative.
    """
    def __init__(self, missing_pg):
        super().__init__(missing_pg)
        self.missing_pg = missing_pg
33
class PluginMgrTestFail(Exception):
    """
    Raised when a setup/teardown command of a test case fails.

    stage   -- the stage that failed (e.g. 'setup' or 'teardown')
    output  -- output associated with the failure
    message -- human-readable description; also the exception's str()
    """
    def __init__(self, stage, output, message):
        # Pass message to Exception so printing the exception object
        # (as test_runner does) shows something useful.
        super().__init__(message)
        self.stage = stage
        self.output = output
        self.message = message
39
class PluginMgr:
    """
    Load tdc plugins and fan hook calls out to them.

    self.plugin_instances is an ordered list of (plugin name, SubPlugin
    instance) tuples.  nsPlugin is always kept at index 0: "pre" hooks
    iterate the list forwards (nsPlugin first) and "post" hooks iterate it
    in reverse (nsPlugin last).
    """

    def __init__(self, argparser):
        super().__init__()
        # Names of every plugin loaded so far.
        self.plugins = set()
        # Ordered (name, instance) tuples; see _add_instance for ordering.
        self.plugin_instances = []
        # NOTE(review): not referenced by any method in this class.
        self.failed_plugins = {}
        self.argparser = argparser

        # Auto-load plugin modules found under TDC_PLUGIN_DIR.
        # NOTE(review): modules are always imported as 'plugins.<name>',
        # whatever directory is walked, so this presumably assumes the
        # default ./plugins package layout -- confirm before relying on a
        # different TDC_PLUGIN_DIR or on nested subdirectories.
        plugindir = os.getenv('TDC_PLUGIN_DIR', './plugins')
        for _dirpath, _dirnames, filenames in os.walk(plugindir):
            for fn in filenames:
                if (not fn.endswith('.py') or fn == '__init__.py' or
                        fn.startswith(('#', '.#'))):
                    continue
                mn = fn[:-3]
                module = importlib.import_module('plugins.' + mn)
                self.plugins.add(mn)
                # Store as (name, instance) like load_plugin does, so every
                # call_* dispatcher can unpack the entries.
                self._add_instance(mn, module.SubPlugin())

    def _add_instance(self, pgname, instance):
        """
        Register a plugin instance, keeping nsPlugin first, and return it.
        """
        if pgname == "nsPlugin":
            self.plugin_instances.insert(0, (pgname, instance))
        else:
            self.plugin_instances.append((pgname, instance))
        return instance

    def load_plugin(self, pgdir, pgname):
        """
        Import module '<pgdir>.<pgname without .py>', instantiate its
        SubPlugin and run its check_args against the parsed arguments.

        Requires set_args() to have been called first (reads self.args).
        """
        pgname = pgname[0:-3]
        self.plugins.add(pgname)

        foo = importlib.import_module('{}.{}'.format(pgdir, pgname))

        # nsPlugin must always be the first one (handled by _add_instance)
        self._add_instance(pgname, foo.SubPlugin()).check_args(self.args, None)

    def get_required_plugins(self, testlist):
        '''
        Get all required plugins from the list of test cases and return
        all unique items.

        Side effect: each test case's 'plugins' entry is replaced by the
        list of plugin names it requires ([] when it requires none), which
        the call_* dispatchers below use for membership checks.
        '''
        reqs = set()
        for t in testlist:
            try:
                plugins = t['plugins']
            except KeyError:
                t['plugins'] = []
                continue
            if 'requires' not in plugins:
                t['plugins'] = []
                continue
            requires = plugins['requires']
            # Normalize a single name to a one-element list so later
            # "name in caseinfo['plugins']" checks are exact membership
            # tests rather than substring tests on a string.
            if not isinstance(requires, list):
                requires = [requires]
            reqs.update(requires)
            t['plugins'] = requires

        return reqs

    def load_required_plugins(self, reqs, parser, args, remaining):
        '''
        Load every required plugin that is not already loaded, searching
        plugin-lib and plugin-lib-custom, then let all plugins add their
        arguments and re-parse the remaining command line into args.

        Raises PluginDependencyException listing the plugin files that
        could not be found.  Returns the updated args namespace.
        '''
        pgd = ['plugin-lib', 'plugin-lib-custom']
        pnf = []

        for r in reqs:
            if r not in self.plugins:
                fname = '{}.py'.format(r)
                source_path = []
                for d in pgd:
                    pgpath = '{}/{}'.format(d, fname)
                    if os.path.isfile(pgpath):
                        source_path.append(pgpath)
                if len(source_path) == 0:
                    print('ERROR: unable to find required plugin {}'.format(r))
                    pnf.append(fname)
                    continue
                elif len(source_path) > 1:
                    print('WARNING: multiple copies of plugin {} found, using version found'.format(r))
                    print('at {}'.format(source_path[0]))
                # The package name is the first path component (the pgd entry).
                pgdir = source_path[0]
                pgdir = pgdir.split('/')[0]
                self.load_plugin(pgdir, fname)
        if len(pnf) > 0:
            raise PluginDependencyException(pnf)

        parser = self.call_add_args(parser)
        (args, remaining) = parser.parse_known_args(args=remaining, namespace=args)
        return args

    def call_pre_suite(self, testcount, testidlist):
        """Call pre_suite on every plugin, in load order."""
        for (_, pgn_inst) in self.plugin_instances:
            pgn_inst.pre_suite(testcount, testidlist)

    def call_post_suite(self, index):
        """Call post_suite on every plugin, in reverse load order."""
        for (_, pgn_inst) in reversed(self.plugin_instances):
            pgn_inst.post_suite(index)

    def call_pre_case(self, caseinfo, *, test_skip=False):
        """
        Call pre_case on each plugin the test case requires, in load order.
        Any exception is reported together with the test id and re-raised.
        """
        for (pgn, pgn_inst) in self.plugin_instances:
            if pgn not in caseinfo['plugins']:
                continue
            try:
                pgn_inst.pre_case(caseinfo, test_skip)
            except Exception as ee:
                print('exception {} in call to pre_case for {} plugin'.
                      format(ee, pgn_inst.__class__))
                print('testid is {}'.format(caseinfo['id']))
                raise

    def call_post_case(self, caseinfo):
        """Call post_case on each required plugin, in reverse load order."""
        for (pgn, pgn_inst) in reversed(self.plugin_instances):
            if pgn not in caseinfo['plugins']:
                continue
            pgn_inst.post_case()

    def call_pre_execute(self, caseinfo):
        """Call pre_execute on each required plugin, in load order."""
        for (pgn, pgn_inst) in self.plugin_instances:
            if pgn not in caseinfo['plugins']:
                continue
            pgn_inst.pre_execute()

    def call_post_execute(self, caseinfo):
        """Call post_execute on each required plugin, in reverse load order."""
        for (pgn, pgn_inst) in reversed(self.plugin_instances):
            if pgn not in caseinfo['plugins']:
                continue
            pgn_inst.post_execute()

    def call_add_args(self, parser):
        """Let every plugin extend the argument parser; return the result."""
        for (pgn, pgn_inst) in self.plugin_instances:
            parser = pgn_inst.add_args(parser)
        return parser

    def call_check_args(self, args, remaining):
        """Let every plugin validate the parsed arguments."""
        for (pgn, pgn_inst) in self.plugin_instances:
            pgn_inst.check_args(args, remaining)

    def call_adjust_command(self, caseinfo, stage, command):
        """
        Pass command through adjust_command of each required plugin, in
        load order, and return the final command string.
        """
        for (pgn, pgn_inst) in self.plugin_instances:
            if pgn not in caseinfo['plugins']:
                continue
            command = pgn_inst.adjust_command(stage, command)
        return command

    def set_args(self, args):
        """Remember the parsed arguments (used by load_plugin)."""
        self.args = args

    @staticmethod
    def _make_argparser(args):
        """
        Return a new ArgumentParser for the tc unit tests.

        A static method has no instance to store the parser on, so it is
        returned instead.  The args parameter is unused and kept only for
        interface compatibility.
        """
        return argparse.ArgumentParser(
            description='Linux TC unit tests')
189            description='Linux TC unit tests')
190
def replace_keywords(cmd):
    """
    Expand $NAME / ${NAME} placeholders in cmd using the NAMES mapping.

    Uses Template.safe_substitute, so placeholders with no entry in NAMES
    are left in the returned string untouched instead of raising.
    """
    return Template(cmd).safe_substitute(NAMES)
198    return subcmd
199
200
def exec_cmd(caseinfo, args, pm, stage, command):
    """
    Perform any required modifications on an executable command, then run
    it in a subprocess and return the results.

    Returns (None, None) for a blank command.  Otherwise returns
    (proc, output): proc is the finished Popen object and output is the
    decoded stderr when the command failed and wrote to stderr, else the
    decoded stdout.  If the command exceeds NAMES['TIMEOUT'] seconds it is
    killed, output is a "timed out" message and proc.returncode is 255.
    """
    if not command.strip():
        return None, None
    if '$' in command:
        command = replace_keywords(command)

    command = pm.call_adjust_command(caseinfo, stage, command)
    if args.verbose > 0:
        print('command "{}"'.format(command))

    proc = subprocess.Popen(command,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=ENVIR)

    try:
        (rawout, serr) = proc.communicate(timeout=NAMES['TIMEOUT'])
    except subprocess.TimeoutExpired:
        # communicate() does not kill the child when the timeout expires;
        # kill it so it does not outlive the test.  wait() (rather than a
        # second communicate()) reaps the shell without blocking on pipes
        # that a stray grandchild might still hold open.
        proc.kill()
        proc.wait()
        foutput = "Command \"{}\" timed out\n".format(command)
        proc.returncode = 255
    else:
        if proc.returncode != 0 and serr:
            foutput = serr.decode("utf-8", errors="ignore")
        else:
            foutput = rawout.decode("utf-8", errors="ignore")

    proc.stdout.close()
    proc.stderr.close()
    return proc, foutput
234
235
def prepare_env(caseinfo, args, pm, stage, prefix, cmdlist, output = None):
    """
    Execute the setup/teardown commands for a test case.
    Optionally terminate test execution if the command fails.

    Each cmdlist entry is either a command string (expected exit code 0)
    or a list [command, allowed_exit_code, ...].  Empty commands are
    skipped.  When a command exits with a disallowed code, the details are
    printed to stderr and PluginMgrTestFail is raised.  Its output field
    carries the caller-supplied output when one was given (e.g. the test's
    own output during teardown), otherwise the failing command's output.
    """
    if args.verbose > 0:
        print('{}'.format(prefix))
    for cmdinfo in cmdlist:
        if isinstance(cmdinfo, list):
            exit_codes = cmdinfo[1:]
            cmd = cmdinfo[0]
        else:
            exit_codes = [0]
            cmd = cmdinfo

        if not cmd:
            continue

        (proc, foutput) = exec_cmd(caseinfo, args, pm, stage, cmd)

        if proc and (proc.returncode not in exit_codes):
            print('', file=sys.stderr)
            print("{} *** Could not execute: \"{}\"".format(prefix, cmd),
                  file=sys.stderr)
            print("\n{} *** Error message: \"{}\"".format(prefix, foutput),
                  file=sys.stderr)
            print("returncode {}; expected {}".format(proc.returncode,
                                                      exit_codes),
                  file=sys.stderr)
            print("\n{} *** Aborting test run.".format(prefix), file=sys.stderr)
            raise PluginMgrTestFail(
                stage, output if output is not None else foutput,
                '"{}" did not complete successfully'.format(prefix))
269                '"{}" did not complete successfully'.format(prefix))
270
271def verify_by_json(procout, res, tidx, args, pm):
272    try:
273        outputJSON = json.loads(procout)
274    except json.JSONDecodeError:
275        res.set_result(ResultState.fail)
276        res.set_failmsg('Cannot decode verify command\'s output. Is it JSON?')
277        return res
278
279    matchJSON = json.loads(json.dumps(tidx['matchJSON']))
280
281    if type(outputJSON) != type(matchJSON):
282        failmsg = 'Original output and matchJSON value are not the same type: output: {} != matchJSON: {} '
283        failmsg = failmsg.format(type(outputJSON).__name__, type(matchJSON).__name__)
284        res.set_result(ResultState.fail)
285        res.set_failmsg(failmsg)
286        return res
287
288    if len(matchJSON) > len(outputJSON):
289        failmsg = "Your matchJSON value is an array, and it contains more elements than the command under test\'s output:\ncommand output (length: {}):\n{}\nmatchJSON value (length: {}):\n{}"
290        failmsg = failmsg.format(len(outputJSON), outputJSON, len(matchJSON), matchJSON)
291        res.set_result(ResultState.fail)
292        res.set_failmsg(failmsg)
293        return res
294    res = find_in_json(res, outputJSON, matchJSON, 0)
295
296    return res
297
298def find_in_json(res, outputJSONVal, matchJSONVal, matchJSONKey=None):
299    if res.get_result() == ResultState.fail:
300        return res
301
302    if type(matchJSONVal) == list:
303        res = find_in_json_list(res, outputJSONVal, matchJSONVal, matchJSONKey)
304
305    elif type(matchJSONVal) == dict:
306        res = find_in_json_dict(res, outputJSONVal, matchJSONVal)
307    else:
308        res = find_in_json_other(res, outputJSONVal, matchJSONVal, matchJSONKey)
309
310    if res.get_result() != ResultState.fail:
311        res.set_result(ResultState.success)
312        return res
313
314    return res
315
316def find_in_json_list(res, outputJSONVal, matchJSONVal, matchJSONKey=None):
317    if (type(matchJSONVal) != type(outputJSONVal)):
318        failmsg = 'Original output and matchJSON value are not the same type: output: {} != matchJSON: {}'
319        failmsg = failmsg.format(outputJSONVal, matchJSONVal)
320        res.set_result(ResultState.fail)
321        res.set_failmsg(failmsg)
322        return res
323
324    if len(matchJSONVal) > len(outputJSONVal):
325        failmsg = "Your matchJSON value is an array, and it contains more elements than the command under test\'s output:\ncommand output (length: {}):\n{}\nmatchJSON value (length: {}):\n{}"
326        failmsg = failmsg.format(len(outputJSONVal), outputJSONVal, len(matchJSONVal), matchJSONVal)
327        res.set_result(ResultState.fail)
328        res.set_failmsg(failmsg)
329        return res
330
331    for matchJSONIdx, matchJSONVal in enumerate(matchJSONVal):
332        res = find_in_json(res, outputJSONVal[matchJSONIdx], matchJSONVal,
333                           matchJSONKey)
334    return res
335
336def find_in_json_dict(res, outputJSONVal, matchJSONVal):
337    for matchJSONKey, matchJSONVal in matchJSONVal.items():
338        if type(outputJSONVal) == dict:
339            if matchJSONKey not in outputJSONVal:
340                failmsg = 'Key not found in json output: {}: {}\nMatching against output: {}'
341                failmsg = failmsg.format(matchJSONKey, matchJSONVal, outputJSONVal)
342                res.set_result(ResultState.fail)
343                res.set_failmsg(failmsg)
344                return res
345
346        else:
347            failmsg = 'Original output and matchJSON value are not the same type: output: {} != matchJSON: {}'
348            failmsg = failmsg.format(type(outputJSON).__name__, type(matchJSON).__name__)
349            res.set_result(ResultState.fail)
350            res.set_failmsg(failmsg)
351            return rest
352
353        if type(outputJSONVal) == dict and (type(outputJSONVal[matchJSONKey]) == dict or
354                type(outputJSONVal[matchJSONKey]) == list):
355            if len(matchJSONVal) > 0:
356                res = find_in_json(res, outputJSONVal[matchJSONKey], matchJSONVal, matchJSONKey)
357            # handling corner case where matchJSONVal == [] or matchJSONVal == {}
358            else:
359                res = find_in_json_other(res, outputJSONVal, matchJSONVal, matchJSONKey)
360        else:
361            res = find_in_json(res, outputJSONVal, matchJSONVal, matchJSONKey)
362    return res
363
364def find_in_json_other(res, outputJSONVal, matchJSONVal, matchJSONKey=None):
365    if matchJSONKey in outputJSONVal:
366        if matchJSONVal != outputJSONVal[matchJSONKey]:
367            failmsg = 'Value doesn\'t match: {}: {} != {}\nMatching against output: {}'
368            failmsg = failmsg.format(matchJSONKey, matchJSONVal, outputJSONVal[matchJSONKey], outputJSONVal)
369            res.set_result(ResultState.fail)
370            res.set_failmsg(failmsg)
371            return res
372
373    return res
374
375def run_one_test(pm, args, index, tidx):
376    global NAMES
377    ns = NAMES['NS']
378    dev0 = NAMES['DEV0']
379    dev1 = NAMES['DEV1']
380    dummy = NAMES['DUMMY']
381    ifb = NAMES['IFB']
382    result = True
383    tresult = ""
384    tap = ""
385    res = TestResult(tidx['id'], tidx['name'])
386    if args.verbose > 0:
387        print("\t====================\n=====> ", end="")
388    print("Test " + tidx["id"] + ": " + tidx["name"])
389
390    if 'skip' in tidx:
391        if tidx['skip'] == 'yes':
392            res = TestResult(tidx['id'], tidx['name'])
393            res.set_result(ResultState.skip)
394            res.set_errormsg('Test case designated as skipped.')
395            pm.call_pre_case(tidx, test_skip=True)
396            pm.call_post_execute(tidx)
397            return res
398
399    if 'dependsOn' in tidx:
400        if (args.verbose > 0):
401            print('probe command for test skip')
402        (p, procout) = exec_cmd(tidx, args, pm, 'execute', tidx['dependsOn'])
403        if p:
404            if (p.returncode != 0):
405                res = TestResult(tidx['id'], tidx['name'])
406                res.set_result(ResultState.skip)
407                res.set_errormsg('probe command: test skipped.')
408                pm.call_pre_case(tidx, test_skip=True)
409                pm.call_post_execute(tidx)
410                return res
411
412    # populate NAMES with TESTID for this test
413    NAMES['TESTID'] = tidx['id']
414    NAMES['NS'] = '{}-{}'.format(NAMES['NS'], tidx['random'])
415    NAMES['DEV0'] = '{}id{}'.format(NAMES['DEV0'], tidx['id'])
416    NAMES['DEV1'] = '{}id{}'.format(NAMES['DEV1'], tidx['id'])
417    NAMES['DUMMY'] = '{}id{}'.format(NAMES['DUMMY'], tidx['id'])
418    NAMES['IFB'] = '{}id{}'.format(NAMES['IFB'], tidx['id'])
419
420    pm.call_pre_case(tidx)
421    prepare_env(tidx, args, pm, 'setup', "-----> prepare stage", tidx["setup"])
422
423    if (args.verbose > 0):
424        print('-----> execute stage')
425    pm.call_pre_execute(tidx)
426    (p, procout) = exec_cmd(tidx, args, pm, 'execute', tidx["cmdUnderTest"])
427    if p:
428        exit_code = p.returncode
429    else:
430        exit_code = None
431
432    pm.call_post_execute(tidx)
433
434    if (exit_code is None or exit_code != int(tidx["expExitCode"])):
435        print("exit: {!r}".format(exit_code))
436        print("exit: {}".format(int(tidx["expExitCode"])))
437        #print("exit: {!r} {}".format(exit_code, int(tidx["expExitCode"])))
438        res.set_result(ResultState.fail)
439        res.set_failmsg('Command exited with {}, expected {}\n{}'.format(exit_code, tidx["expExitCode"], procout))
440        print(procout)
441    else:
442        if args.verbose > 0:
443            print('-----> verify stage')
444        (p, procout) = exec_cmd(tidx, args, pm, 'verify', tidx["verifyCmd"])
445        if procout:
446            if 'matchJSON' in tidx:
447                verify_by_json(procout, res, tidx, args, pm)
448            elif 'matchPattern' in tidx:
449                match_pattern = re.compile(
450                    str(tidx["matchPattern"]), re.DOTALL | re.MULTILINE)
451                match_index = re.findall(match_pattern, procout)
452                if len(match_index) != int(tidx["matchCount"]):
453                    res.set_result(ResultState.fail)
454                    res.set_failmsg('Could not match regex pattern. Verify command output:\n{}'.format(procout))
455                else:
456                    res.set_result(ResultState.success)
457            else:
458                res.set_result(ResultState.fail)
459                res.set_failmsg('Must specify a match option: matchJSON or matchPattern\n{}'.format(procout))
460        elif int(tidx["matchCount"]) != 0:
461            res.set_result(ResultState.fail)
462            res.set_failmsg('No output generated by verify command.')
463        else:
464            res.set_result(ResultState.success)
465
466    prepare_env(tidx, args, pm, 'teardown', '-----> teardown stage', tidx['teardown'], procout)
467    pm.call_post_case(tidx)
468
469    index += 1
470
471    # remove TESTID from NAMES
472    del(NAMES['TESTID'])
473
474    # Restore names
475    NAMES['NS'] = ns
476    NAMES['DEV0'] = dev0
477    NAMES['DEV1'] = dev1
478    NAMES['DUMMY'] = dummy
479    NAMES['IFB'] = ifb
480
481    return res
482
483def prepare_run(pm, args, testlist):
484    tcount = len(testlist)
485    emergency_exit = False
486    emergency_exit_message = ''
487
488    try:
489        pm.call_pre_suite(tcount, testlist)
490    except Exception as ee:
491        ex_type, ex, ex_tb = sys.exc_info()
492        print('Exception {} {} (caught in pre_suite).'.
493              format(ex_type, ex))
494        traceback.print_tb(ex_tb)
495        emergency_exit_message = 'EMERGENCY EXIT, call_pre_suite failed with exception {} {}\n'.format(ex_type, ex)
496        emergency_exit = True
497
498    if emergency_exit:
499        pm.call_post_suite(1)
500        return emergency_exit_message
501
502def purge_run(pm, index):
503    pm.call_post_suite(index)
504
505def test_runner(pm, args, filtered_tests):
506    """
507    Driver function for the unit tests.
508
509    Prints information about the tests being run, executes the setup and
510    teardown commands and the command under test itself. Also determines
511    success/failure based on the information in the test case and generates
512    TAP output accordingly.
513    """
514    testlist = filtered_tests
515    tcount = len(testlist)
516    index = 1
517    tap = ''
518    badtest = None
519    stage = None
520
521    tsr = TestSuiteReport()
522
523    for tidx in testlist:
524        if "flower" in tidx["category"] and args.device == None:
525            errmsg = "Tests using the DEV2 variable must define the name of a "
526            errmsg += "physical NIC with the -d option when running tdc.\n"
527            errmsg += "Test has been skipped."
528            if args.verbose > 1:
529                print(errmsg)
530            res = TestResult(tidx['id'], tidx['name'])
531            res.set_result(ResultState.skip)
532            res.set_errormsg(errmsg)
533            tsr.add_resultdata(res)
534            index += 1
535            continue
536        try:
537            badtest = tidx  # in case it goes bad
538            res = run_one_test(pm, args, index, tidx)
539            tsr.add_resultdata(res)
540        except PluginMgrTestFail as pmtf:
541            ex_type, ex, ex_tb = sys.exc_info()
542            stage = pmtf.stage
543            message = pmtf.message
544            output = pmtf.output
545            res = TestResult(tidx['id'], tidx['name'])
546            res.set_result(ResultState.fail)
547            res.set_errormsg(pmtf.message)
548            res.set_failmsg(pmtf.output)
549            tsr.add_resultdata(res)
550            index += 1
551            print(message)
552            print('Exception {} {} (caught in test_runner, running test {} {} {} stage {})'.
553                  format(ex_type, ex, index, tidx['id'], tidx['name'], stage))
554            print('---------------')
555            print('traceback')
556            traceback.print_tb(ex_tb)
557            print('---------------')
558            if stage == 'teardown':
559                print('accumulated output for this test:')
560                if pmtf.output:
561                    print(pmtf.output)
562            print('---------------')
563            break
564        index += 1
565
566    # if we failed in setup or teardown,
567    # fill in the remaining tests with ok-skipped
568    count = index
569
570    if tcount + 1 != count:
571        for tidx in testlist[count - 1:]:
572            res = TestResult(tidx['id'], tidx['name'])
573            res.set_result(ResultState.skip)
574            msg = 'skipped - previous {} failed {} {}'.format(stage,
575                index, badtest.get('id', '--Unknown--'))
576            res.set_errormsg(msg)
577            tsr.add_resultdata(res)
578            count += 1
579
580    if args.pause:
581        print('Want to pause\nPress enter to continue ...')
582        if input(sys.stdin):
583            print('got something on stdin')
584
585    return (index, tsr)
586
587def mp_bins(alltests):
588    serial = []
589    parallel = []
590
591    for test in alltests:
592        if 'nsPlugin' not in test['plugins']:
593            serial.append(test)
594        else:
595            # We can only create one netdevsim device at a time
596            if 'netdevsim/new_device' in str(test['setup']):
597                serial.append(test)
598            else:
599                parallel.append(test)
600
601    return (serial, parallel)
602
603def __mp_runner(tests):
604    (_, tsr) = test_runner(mp_pm, mp_args, tests)
605    return tsr._testsuite
606
607def test_runner_mp(pm, args, alltests):
608    prepare_run(pm, args, alltests)
609
610    (serial, parallel) = mp_bins(alltests)
611
612    batches = [parallel[n : n + 32] for n in range(0, len(parallel), 32)]
613    batches.insert(0, serial)
614
615    print("Executing {} tests in parallel and {} in serial".format(len(parallel), len(serial)))
616    print("Using {} batches and {} workers".format(len(batches), args.mp))
617
618    # We can't pickle these objects so workaround them
619    global mp_pm
620    mp_pm = pm
621
622    global mp_args
623    mp_args = args
624
625    with Pool(args.mp) as p:
626        pres = p.map(__mp_runner, batches)
627
628    tsr = TestSuiteReport()
629    for trs in pres:
630        for res in trs:
631            tsr.add_resultdata(res)
632
633    # Passing an index is not useful in MP
634    purge_run(pm, None)
635
636    return tsr
637
638def test_runner_serial(pm, args, alltests):
639    prepare_run(pm, args, alltests)
640
641    if args.verbose:
642        print("Executing {} tests in serial".format(len(alltests)))
643
644    (index, tsr) = test_runner(pm, args, alltests)
645
646    purge_run(pm, index)
647
648    return tsr
649
650def has_blank_ids(idlist):
651    """
652    Search the list for empty ID fields and return true/false accordingly.
653    """
654    return not(all(k for k in idlist))
655
656
657def load_from_file(filename):
658    """
659    Open the JSON file containing the test cases and return them
660    as list of ordered dictionary objects.
661    """
662    try:
663        with open(filename) as test_data:
664            testlist = json.load(test_data, object_pairs_hook=OrderedDict)
665    except json.JSONDecodeError as jde:
666        print('IGNORING test case file {}\n\tBECAUSE:  {}'.format(filename, jde))
667        testlist = list()
668    else:
669        idlist = get_id_list(testlist)
670        if (has_blank_ids(idlist)):
671            for k in testlist:
672                k['filename'] = filename
673    return testlist
674
675def identity(string):
676    return string
677
678def args_parse():
679    """
680    Create the argument parser.
681    """
682    parser = argparse.ArgumentParser(description='Linux TC unit tests')
683    parser.register('type', None, identity)
684    return parser
685
686
687def set_args(parser):
688    """
689    Set the command line arguments for tdc.
690    """
691    parser.add_argument(
692        '--outfile', type=str,
693        help='Path to the file in which results should be saved. ' +
694        'Default target is the current directory.')
695    parser.add_argument(
696        '-p', '--path', type=str,
697        help='The full path to the tc executable to use')
698    sg = parser.add_argument_group(
699        'selection', 'select which test cases: ' +
700        'files plus directories; filtered by categories plus testids')
701    ag = parser.add_argument_group(
702        'action', 'select action to perform on selected test cases')
703
704    sg.add_argument(
705        '-D', '--directory', nargs='+', metavar='DIR',
706        help='Collect tests from the specified directory(ies) ' +
707        '(default [tc-tests])')
708    sg.add_argument(
709        '-f', '--file', nargs='+', metavar='FILE',
710        help='Run tests from the specified file(s)')
711    sg.add_argument(
712        '-c', '--category', nargs='*', metavar='CATG', default=['+c'],
713        help='Run tests only from the specified category/ies, ' +
714        'or if no category/ies is/are specified, list known categories.')
715    sg.add_argument(
716        '-e', '--execute', nargs='+', metavar='ID',
717        help='Execute the specified test cases with specified IDs')
718    ag.add_argument(
719        '-l', '--list', action='store_true',
720        help='List all test cases, or those only within the specified category')
721    ag.add_argument(
722        '-s', '--show', action='store_true', dest='showID',
723        help='Display the selected test cases')
724    ag.add_argument(
725        '-i', '--id', action='store_true', dest='gen_id',
726        help='Generate ID numbers for new test cases')
727    parser.add_argument(
728        '-v', '--verbose', action='count', default=0,
729        help='Show the commands that are being run')
730    parser.add_argument(
731        '--format', default='tap', const='tap', nargs='?',
732        choices=['none', 'xunit', 'tap'],
733        help='Specify the format for test results. (Default: TAP)')
734    parser.add_argument('-d', '--device',
735                        help='Execute test cases that use a physical device, ' +
736                        'where DEVICE is its name. (If not defined, tests ' +
737                        'that require a physical device will be skipped)')
738    parser.add_argument(
739        '-P', '--pause', action='store_true',
740        help='Pause execution just before post-suite stage')
741    parser.add_argument(
742        '-J', '--multiprocess', type=int, default=1, dest='mp',
743        help='Run tests in parallel whenever possible')
744    return parser
745
746
747def check_default_settings(args, remaining, pm):
748    """
749    Process any arguments overriding the default settings,
750    and ensure the settings are correct.
751    """
752    # Allow for overriding specific settings
753    global NAMES
754
755    if args.path != None:
756        NAMES['TC'] = args.path
757    if args.device != None:
758        NAMES['DEV2'] = args.device
759    if 'TIMEOUT' not in NAMES:
760        NAMES['TIMEOUT'] = None
761    if 'ETHTOOL' in NAMES and not os.path.isfile(NAMES['ETHTOOL']):
762        print(f"The specified ethtool path {NAMES['ETHTOOL']} does not exist.")
763        exit(1)
764    if not os.path.isfile(NAMES['TC']):
765        print("The specified tc path " + NAMES['TC'] + " does not exist.")
766        exit(1)
767
768    pm.call_check_args(args, remaining)
769
770
771def get_id_list(alltests):
772    """
773    Generate a list of all IDs in the test cases.
774    """
775    return [x["id"] for x in alltests]
776
777def check_case_id(alltests):
778    """
779    Check for duplicate test case IDs.
780    """
781    idl = get_id_list(alltests)
782    return [x for x in idl if idl.count(x) > 1]
783
784
785def does_id_exist(alltests, newid):
786    """
787    Check if a given ID already exists in the list of test cases.
788    """
789    idl = get_id_list(alltests)
790    return (any(newid == x for x in idl))
791
792
793def generate_case_ids(alltests):
794    """
795    If a test case has a blank ID field, generate a random hex ID for it
796    and then write the test cases back to disk.
797    """
798    for c in alltests:
799        if (c["id"] == ""):
800            while True:
801                newid = str('{:04x}'.format(random.randrange(16**4)))
802                if (does_id_exist(alltests, newid)):
803                    continue
804                else:
805                    c['id'] = newid
806                    break
807
808    ufilename = []
809    for c in alltests:
810        if ('filename' in c):
811            ufilename.append(c['filename'])
812    ufilename = get_unique_item(ufilename)
813    for f in ufilename:
814        testlist = []
815        for t in alltests:
816            if 'filename' in t:
817                if t['filename'] == f:
818                    del t['filename']
819                    testlist.append(t)
820        outfile = open(f, "w")
821        json.dump(testlist, outfile, indent=4)
822        outfile.write("\n")
823        outfile.close()
824
825def filter_tests_by_id(args, testlist):
826    '''
827    Remove tests from testlist that are not in the named id list.
828    If id list is empty, return empty list.
829    '''
830    newlist = list()
831    if testlist and args.execute:
832        target_ids = args.execute
833
834        if isinstance(target_ids, list) and (len(target_ids) > 0):
835            newlist = list(filter(lambda x: x['id'] in target_ids, testlist))
836    return newlist
837
838def filter_tests_by_category(args, testlist):
839    '''
840    Remove tests from testlist that are not in a named category.
841    '''
842    answer = list()
843    if args.category and testlist:
844        test_ids = list()
845        for catg in set(args.category):
846            if catg == '+c':
847                continue
848            print('considering category {}'.format(catg))
849            for tc in testlist:
850                if catg in tc['category'] and tc['id'] not in test_ids:
851                    answer.append(tc)
852                    test_ids.append(tc['id'])
853
854    return answer
855
def set_random(alltests):
    """
    Store a fresh 32-bit random number under the 'random' key of every
    test case in alltests (modified in place).
    """
    draw = random.getrandbits
    for testcase in alltests:
        testcase['random'] = draw(32)
859
def get_test_cases(args):
    """
    Collect the test cases for this run.

    Test case files are taken from:
      - args.file; paths that are not existing regular files are reported
        and skipped;
      - every *.json file found by walking the test directories. These are
        args.directory if given, otherwise 'tc-tests' unless args.file was
        given.
    Each file is loaded at most once, even when it is named explicitly and
    also found while walking a directory.

    The loaded cases are then narrowed by ID (args.execute) and/or category
    (args.category):
      - IDs given: the category matches (if categories were given) plus the
        ID matches not already included;
      - only categories given: the category matches, or, if none matched,
        everything that was loaded;
      - neither: everything that was loaded.

    :return: (allcatlist, allidlist, testcases_by_cats, alltestcases). The
             first three describe every loaded test case; alltestcases is
             the filtered selection.
    """
    import fnmatch

    flist = []
    seen_files = set()

    def add_file(path):
        # Loading the same file twice would duplicate every test ID in it
        # and make the run abort on the duplicate-ID check.
        if path not in seen_files:
            seen_files.add(path)
            flist.append(path)

    testdirs = ['tc-tests']

    if args.file:
        # at least one file was specified - remove the default directory
        testdirs = []

        for ff in args.file:
            if not os.path.isfile(ff):
                print("IGNORING file " + ff + "\n\tBECAUSE does not exist.")
            else:
                add_file(os.path.abspath(ff))

    if args.directory:
        testdirs = args.directory

    for testdir in testdirs:
        for root, _dirnames, filenames in os.walk(testdir):
            for filename in fnmatch.filter(filenames, '*.json'):
                add_file(os.path.abspath(os.path.join(root, filename)))

    alltestcases = list()
    for casefile in flist:
        alltestcases.extend(load_from_file(casefile))

    allcatlist = get_test_categories(alltestcases)
    allidlist = get_id_list(alltestcases)

    testcases_by_cats = get_categorized_testlist(alltestcases, allcatlist)
    idtestcases = filter_tests_by_id(args, alltestcases)
    cattestcases = filter_tests_by_category(args, alltestcases)

    cat_ids = [x['id'] for x in cattestcases]
    if args.execute:
        if args.category:
            alltestcases = cattestcases + [x for x in idtestcases if x['id'] not in cat_ids]
        else:
            alltestcases = idtestcases
    elif cat_ids:
        alltestcases = cattestcases
    # Otherwise keep alltestcases as loaded (already limited by
    # file/directory). NOTE(review): this includes the case where
    # categories were given but none matched.

    return allcatlist, allidlist, testcases_by_cats, alltestcases
919
920
def _report_results(args, catresults):
    """
    Print the test results in the format chosen by args.format and save
    them to args.outfile, or to test-results.<xml|tap> when no output file
    was given. With format 'none' only a notice is printed.

    If SUDO_UID is set in the environment, the results file is chowned to
    that user (and to SUDO_GID when that is also set), so a run under sudo
    does not leave a root-owned file behind.
    """
    if args.format == 'none':
        print('Test results output suppression requested\n')
        return

    print('\nAll test results: \n')
    if args.format == 'xunit':
        suffix = 'xml'
        res = catresults.format_xunit()
    elif args.format == 'tap':
        suffix = 'tap'
        res = catresults.format_tap()
    print(res)
    print('\n\n')

    fname = args.outfile if args.outfile else 'test-results.{}'.format(suffix)
    with open(fname, 'w') as fh:
        fh.write(res)

    sudo_uid = os.getenv('SUDO_UID')
    if sudo_uid is not None:
        sudo_gid = os.getenv('SUDO_GID')
        # gid -1 leaves the group unchanged instead of crashing on
        # int(None) when SUDO_GID is missing.
        os.chown(fname, uid=int(sudo_uid),
                 gid=int(sudo_gid) if sudo_gid is not None else -1)


def set_operation_mode(pm, parser, args, remaining):
    """
    Load the test case data and act on the mode requested in args.

    Modes are checked in this order, and each one ends the process:
      - args.gen_id: fill in blank test IDs and rewrite the files; exit 0.
      - duplicate test IDs: report them; exit 1.
      - args.showID: print every selected test case; exit 0.
      - empty args.category list: print the known categories; exit 0.
      - args.list: list the selected test cases; exit 0.
      - otherwise run the selected tests (in parallel when args.mp > 1),
        report the results, and exit 0 (KSFT_PASS) or 1 (KSFT_FAIL, some
        test failed); exit 4 (KSFT_SKIP) when no tests were selected.

    Uses sys.exit(); the site-provided exit() builtin is meant for the
    interactive interpreter only.
    """
    ucat, idlist, _, alltests = get_test_cases(args)

    if args.gen_id:
        if has_blank_ids(idlist):
            generate_case_ids(alltests)
        else:
            print("No empty ID fields found in test files.")
        sys.exit(0)

    duplicate_ids = check_case_id(alltests)
    if len(duplicate_ids) > 0:
        print("The following test case IDs are not unique:")
        print(str(set(duplicate_ids)))
        print("Please correct them before continuing.")
        sys.exit(1)

    if args.showID:
        for atest in alltests:
            print_test_case(atest)
        sys.exit(0)

    # An empty list presumably means the category option was given without
    # any names; treat it as a request to list the categories.
    if isinstance(args.category, list) and len(args.category) == 0:
        print("Available categories:")
        print_sll(ucat)
        sys.exit(0)

    if args.list:
        list_test_cases(alltests)
        sys.exit(0)

    set_random(alltests)

    if not alltests:
        print('No tests found\n')
        sys.exit(4)  # KSFT_SKIP

    req_plugins = pm.get_required_plugins(alltests)
    try:
        args = pm.load_required_plugins(req_plugins, parser, args, remaining)
    except PluginDependencyException as pde:
        # NOTE(review): the run continues without the missing plugins.
        print('The following plugins were not found:')
        print('{}'.format(pde.missing_pg))

    if args.mp > 1:
        catresults = test_runner_mp(pm, args, alltests)
    else:
        catresults = test_runner_serial(pm, args, alltests)

    exit_code = 1 if catresults.count_failures() != 0 else 0  # KSFT_FAIL / KSFT_PASS
    _report_results(args, catresults)
    sys.exit(exit_code)
1001
def main():
    """
    Start of execution.

    Build the argument parser, including options added by plugins, and
    parse the command line. Then run the requested operation. Requires
    Python 3.8 or newer. On Ctrl-C the plugins' post-suite hooks are still
    called so they can clean up.
    """
    import resource

    # Compare the whole version tuple: testing major and minor separately
    # would wrongly reject e.g. Python 4.0.
    if sys.version_info < (3, 8):
        sys.exit("tdc requires at least python 3.8")

    # Raise both the soft and hard open-file limits to 1048576.
    resource.setrlimit(resource.RLIMIT_NOFILE, (1048576, 1048576))

    parser = args_parse()
    parser = set_args(parser)
    pm = PluginMgr(parser)
    parser = pm.call_add_args(parser)
    (args, remaining) = parser.parse_known_args()
    args.NAMES = NAMES
    # Never use more than 4 worker processes.
    args.mp = min(args.mp, 4)
    pm.set_args(args)
    check_default_settings(args, remaining, pm)
    if args.verbose > 2:
        print('args is {}'.format(args))

    try:
        set_operation_mode(pm, parser, args, remaining)
    except KeyboardInterrupt:
        # Cleanup on Ctrl-C
        pm.call_post_suite(None)


if __name__ == "__main__":
    main()
1034