xref: /linux/tools/testing/selftests/tc-testing/tdc.py (revision 5754a1c9f9b6e298791c4bb34263f37dfe93ee35)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: GPL-2.0
3
4"""
5tdc.py - Linux tc (Traffic Control) unit test driver
6
7Copyright (C) 2017 Lucas Bates <lucasb@mojatatu.com>
8"""
9
10import re
11import os
12import sys
13import argparse
14import importlib
15import json
16import subprocess
17import time
18import traceback
19import random
20from multiprocessing import Pool
21from collections import OrderedDict
22from string import Template
23
24from tdc_config import *
25from tdc_helper import *
26
27import TdcPlugin
28from TdcResults import *
29
class PluginDependencyException(Exception):
    """Raised when one or more plugins required by test cases cannot be found.

    Attributes:
        missing_pg: list of plugin filenames that could not be located.
    """
    def __init__(self, missing_pg):
        # Initialize the base Exception so str()/repr() carry the payload
        # instead of being empty.
        super().__init__(missing_pg)
        self.missing_pg = missing_pg
33
class PluginMgrTestFail(Exception):
    """Raised when a test stage (e.g. 'setup' or 'teardown') fails.

    Attributes:
        stage:   name of the stage that failed.
        output:  accumulated command output at the time of failure (may be None).
        message: human-readable description of the failure.
    """
    def __init__(self, stage, output, message):
        # Pass the message to the base Exception so str(e) is meaningful.
        super().__init__(message)
        self.stage = stage
        self.output = output
        self.message = message
39
class PluginMgr:
    """Discover, load and dispatch tdc plugins.

    Plugins found in TDC_PLUGIN_DIR (default ./plugins) are imported at
    construction time; plugins required by individual test cases are added
    later via load_required_plugins().  Loaded plugins are held as
    (name, instance) tuples in self.plugin_instances, and the call_*()
    methods forward each test lifecycle event to every plugin enabled for
    the current test case.
    """

    def __init__(self, argparser):
        super().__init__()
        self.plugins = set()          # names of every loaded plugin
        self.plugin_instances = []    # list of (name, SubPlugin instance) tuples
        self.failed_plugins = {}
        self.argparser = argparser

        plugindir = os.getenv('TDC_PLUGIN_DIR', './plugins')
        for dirpath, dirnames, filenames in os.walk(plugindir):
            for fn in filenames:
                if (fn.endswith('.py') and
                    not fn == '__init__.py' and
                    not fn.startswith('#') and
                    not fn.startswith('.#')):
                    mn = fn[0:-3]
                    foo = importlib.import_module('plugins.' + mn)
                    self.plugins.add(mn)
                    # Bug fix: store the same (name, instance) tuple shape
                    # that every call_*() dispatcher iterates over.  The
                    # old code indexed the *list* with a string
                    # (self.plugin_instances[mn] = ...), which raised
                    # TypeError as soon as any plugin was present.
                    self.plugin_instances.append((mn, foo.SubPlugin()))

    def load_plugin(self, pgdir, pgname):
        """Import one plugin module and register its SubPlugin instance."""
        pgname = pgname[0:-3]   # strip the '.py' suffix
        self.plugins.add(pgname)

        foo = importlib.import_module('{}.{}'.format(pgdir, pgname))

        # nsPlugin must always be the first one
        if pgname == "nsPlugin":
            self.plugin_instances.insert(0, (pgname, foo.SubPlugin()))
            self.plugin_instances[0][1].check_args(self.args, None)
        else:
            self.plugin_instances.append((pgname, foo.SubPlugin()))
            self.plugin_instances[-1][1].check_args(self.args, None)

    def get_required_plugins(self, testlist):
        '''
        Get all required plugins from the list of test cases and return
        all unique items.
        '''
        reqs = set()
        for t in testlist:
            try:
                if 'requires' in t['plugins']:
                    if isinstance(t['plugins']['requires'], list):
                        reqs.update(set(t['plugins']['requires']))
                    else:
                        reqs.add(t['plugins']['requires'])
                    # Flatten so later stages can test membership directly.
                    t['plugins'] = t['plugins']['requires']
                else:
                    t['plugins'] = []
            except KeyError:
                # Test case has no 'plugins' key at all.
                t['plugins'] = []
                continue

        return reqs

    def load_required_plugins(self, reqs, parser, args, remaining):
        '''
        Get all required plugins from the list of test cases and load any plugin
        that is not already enabled.
        '''
        pgd = ['plugin-lib', 'plugin-lib-custom']
        pnf = []

        for r in reqs:
            if r not in self.plugins:
                fname = '{}.py'.format(r)
                source_path = []
                for d in pgd:
                    pgpath = '{}/{}'.format(d, fname)
                    if os.path.isfile(pgpath):
                        source_path.append(pgpath)
                if len(source_path) == 0:
                    print('ERROR: unable to find required plugin {}'.format(r))
                    pnf.append(fname)
                    continue
                elif len(source_path) > 1:
                    # Bug fix: the plugin name was never substituted into
                    # the warning message (missing .format(r)).
                    print('WARNING: multiple copies of plugin {} found, using version found'.format(r))
                    print('at {}'.format(source_path[0]))
                pgdir = source_path[0]
                pgdir = pgdir.split('/')[0]
                self.load_plugin(pgdir, fname)
        if len(pnf) > 0:
            raise PluginDependencyException(pnf)

        parser = self.call_add_args(parser)
        (args, remaining) = parser.parse_known_args(args=remaining, namespace=args)
        return args

    def call_pre_suite(self, testcount, testidlist):
        """Run every plugin's pre_suite hook before any test executes."""
        for (_, pgn_inst) in self.plugin_instances:
            pgn_inst.pre_suite(testcount, testidlist)

    def call_post_suite(self, index):
        """Run every plugin's post_suite hook, in reverse load order."""
        for (_, pgn_inst) in reversed(self.plugin_instances):
            pgn_inst.post_suite(index)

    def call_pre_case(self, caseinfo, *, test_skip=False):
        """Run pre_case on the plugins enabled for this test case."""
        for (pgn, pgn_inst) in self.plugin_instances:
            if pgn not in caseinfo['plugins']:
                continue
            try:
                pgn_inst.pre_case(caseinfo, test_skip)
            except Exception as ee:
                print('exception {} in call to pre_case for {} plugin'.
                      format(ee, pgn_inst.__class__))
                print('testid is {}'.format(caseinfo['id']))
                raise

    def call_post_case(self, caseinfo):
        """Run post_case on the enabled plugins, in reverse load order."""
        for (pgn, pgn_inst) in reversed(self.plugin_instances):
            if pgn not in caseinfo['plugins']:
                continue
            pgn_inst.post_case()

    def call_pre_execute(self, caseinfo):
        """Run pre_execute on the plugins enabled for this test case."""
        for (pgn, pgn_inst) in self.plugin_instances:
            if pgn not in caseinfo['plugins']:
                continue
            pgn_inst.pre_execute()

    def call_post_execute(self, caseinfo):
        """Run post_execute on the enabled plugins, in reverse load order."""
        for (pgn, pgn_inst) in reversed(self.plugin_instances):
            if pgn not in caseinfo['plugins']:
                continue
            pgn_inst.post_execute()

    def call_add_args(self, parser):
        """Let every plugin add its own command-line arguments."""
        for (pgn, pgn_inst) in self.plugin_instances:
            parser = pgn_inst.add_args(parser)
        return parser

    def call_check_args(self, args, remaining):
        """Let every plugin validate the parsed command-line arguments."""
        for (pgn, pgn_inst) in self.plugin_instances:
            pgn_inst.check_args(args, remaining)

    def call_adjust_command(self, caseinfo, stage, command):
        """Give each enabled plugin a chance to rewrite a stage command."""
        for (pgn, pgn_inst) in self.plugin_instances:
            if pgn not in caseinfo['plugins']:
                continue
            command = pgn_inst.adjust_command(stage, command)
        return command

    def set_args(self, args):
        """Remember the parsed arguments for later plugin loading."""
        self.args = args

    @staticmethod
    def _make_argparser(args):
        # Bug fix: this staticmethod referenced 'self', which raised
        # NameError; build and return the parser instead.
        return argparse.ArgumentParser(description='Linux TC unit tests')
189            description='Linux TC unit tests')
190
def replace_keywords(cmd):
    """Expand every known $VARIABLE in *cmd* from the NAMES table.

    Unknown placeholders are left untouched (safe_substitute).
    """
    return Template(cmd).safe_substitute(NAMES)
199
200
def exec_cmd(caseinfo, args, pm, stage, command):
    """Run one shell command for a test stage and capture its output.

    The command is keyword-expanded, offered to plugins for adjustment,
    then executed with a NAMES['TIMEOUT'] limit.  Returns (proc, text)
    where text is stderr on a non-zero exit (when non-empty) and stdout
    otherwise; a blank command yields (None, None).
    """
    if not command.strip():
        return None, None
    if '$' in command:
        command = replace_keywords(command)

    command = pm.call_adjust_command(caseinfo, stage, command)
    if args.verbose > 0:
        print('command "{}"'.format(command))

    proc = subprocess.Popen(command, shell=True,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE,
                            env=ENVIR)

    try:
        rawout, serr = proc.communicate(timeout=NAMES['TIMEOUT'])
        # Prefer stderr as the reported output when the command failed
        # and actually produced error text.
        if proc.returncode != 0 and len(serr) > 0:
            decoded = serr.decode("utf-8", errors="ignore")
        else:
            decoded = rawout.decode("utf-8", errors="ignore")
    except subprocess.TimeoutExpired:
        decoded = "Command \"{}\" timed out\n".format(command)
        proc.returncode = 255

    proc.stdout.close()
    proc.stderr.close()
    return proc, decoded
234
235
def prepare_env(caseinfo, args, pm, stage, prefix, cmdlist, output = None):
    """
    Execute the setup/teardown commands for a test case.
    Optionally terminate test execution if the command fails.

    Each entry of cmdlist is either a plain command string (expected to
    exit 0) or a [command, allowed_exit_code, ...] list.  A command that
    exits with an unexpected code aborts the run via PluginMgrTestFail.
    """
    if args.verbose > 0:
        print('{}'.format(prefix))
    for cmdinfo in cmdlist:
        if isinstance(cmdinfo, list):
            exit_codes = cmdinfo[1:]
            cmd = cmdinfo[0]
        else:
            exit_codes = [0]
            cmd = cmdinfo

        if not cmd:
            continue

        (proc, foutput) = exec_cmd(caseinfo, args, pm, stage, cmd)

        if proc and (proc.returncode not in exit_codes):
            print('', file=sys.stderr)
            print("{} *** Could not execute: \"{}\"".format(prefix, cmd),
                  file=sys.stderr)
            print("\n{} *** Error message: \"{}\"".format(prefix, foutput),
                  file=sys.stderr)
            # Bug fix: this diagnostic went to stdout while every other
            # message in this error path goes to stderr.
            print("returncode {}; expected {}".format(proc.returncode,
                                                      exit_codes),
                  file=sys.stderr)
            print("\n{} *** Aborting test run.".format(prefix), file=sys.stderr)
            print("\n\n{} *** stdout ***".format(proc.stdout), file=sys.stderr)
            print("\n\n{} *** stderr ***".format(proc.stderr), file=sys.stderr)
            raise PluginMgrTestFail(
                stage, output,
                '"{}" did not complete successfully'.format(prefix))
270
def verify_by_json(procout, res, tidx, args, pm):
    """Validate the verify command's JSON output against tidx['matchJSON'].

    Updates *res* in place (fail on decode/type/length mismatch) and
    returns it; structural matching is delegated to find_in_json().
    """
    try:
        outputJSON = json.loads(procout)
    except json.JSONDecodeError:
        res.set_result(ResultState.fail)
        res.set_failmsg('Cannot decode verify command\'s output. Is it JSON?')
        return res

    # Deep-copy the expected value via a dump/load round trip.
    matchJSON = json.loads(json.dumps(tidx['matchJSON']))

    if type(outputJSON) != type(matchJSON):
        res.set_result(ResultState.fail)
        res.set_failmsg(
            'Original output and matchJSON value are not the same type: output: {} != matchJSON: {} '
            .format(type(outputJSON).__name__, type(matchJSON).__name__))
        return res

    if len(matchJSON) > len(outputJSON):
        res.set_result(ResultState.fail)
        res.set_failmsg(
            "Your matchJSON value is an array, and it contains more elements than the command under test\'s output:\ncommand output (length: {}):\n{}\nmatchJSON value (length: {}):\n{}"
            .format(len(outputJSON), outputJSON, len(matchJSON), matchJSON))
        return res

    return find_in_json(res, outputJSON, matchJSON, 0)
297
def find_in_json(res, outputJSONVal, matchJSONVal, matchJSONKey=None):
    """Recursively match matchJSONVal inside outputJSONVal.

    Dispatches on the expected value's type (list / dict / scalar) and
    marks *res* success when no sub-match has failed.
    """
    # Once a failure is recorded, short-circuit the whole recursion.
    if res.get_result() == ResultState.fail:
        return res

    if type(matchJSONVal) == list:
        res = find_in_json_list(res, outputJSONVal, matchJSONVal, matchJSONKey)
    elif type(matchJSONVal) == dict:
        res = find_in_json_dict(res, outputJSONVal, matchJSONVal)
    else:
        res = find_in_json_other(res, outputJSONVal, matchJSONVal, matchJSONKey)

    if res.get_result() != ResultState.fail:
        res.set_result(ResultState.success)
    return res
315
def find_in_json_list(res, outputJSONVal, matchJSONVal, matchJSONKey=None):
    """Match a list of expected values element-by-element against the output.

    Fails on a type mismatch or when the expected list is longer than
    the output; otherwise recurses into each element pair.
    """
    if type(matchJSONVal) != type(outputJSONVal):
        res.set_result(ResultState.fail)
        res.set_failmsg(
            'Original output and matchJSON value are not the same type: output: {} != matchJSON: {}'
            .format(outputJSONVal, matchJSONVal))
        return res

    if len(matchJSONVal) > len(outputJSONVal):
        res.set_result(ResultState.fail)
        res.set_failmsg(
            "Your matchJSON value is an array, and it contains more elements than the command under test\'s output:\ncommand output (length: {}):\n{}\nmatchJSON value (length: {}):\n{}"
            .format(len(outputJSONVal), outputJSONVal, len(matchJSONVal), matchJSONVal))
        return res

    for idx, expected in enumerate(matchJSONVal):
        res = find_in_json(res, outputJSONVal[idx], expected, matchJSONKey)
    return res
335
def find_in_json_dict(res, outputJSONVal, matchJSONVal):
    """
    Match each key/value pair of the expected dict against the command
    output, recursing into nested dicts and lists.  Records a failure on
    *res* (and returns early) for a missing key or a type mismatch.
    """
    for matchJSONKey, matchJSONVal in matchJSONVal.items():
        if type(outputJSONVal) == dict:
            if matchJSONKey not in outputJSONVal:
                failmsg = 'Key not found in json output: {}: {}\nMatching against output: {}'
                failmsg = failmsg.format(matchJSONKey, matchJSONVal, outputJSONVal)
                res.set_result(ResultState.fail)
                res.set_failmsg(failmsg)
                return res

        else:
            failmsg = 'Original output and matchJSON value are not the same type: output: {} != matchJSON: {}'
            # Bug fix: referenced the undefined names outputJSON/matchJSON,
            # which raised NameError instead of reporting the mismatch.
            failmsg = failmsg.format(type(outputJSONVal).__name__, type(matchJSONVal).__name__)
            res.set_result(ResultState.fail)
            res.set_failmsg(failmsg)
            # Bug fix: was 'return rest' (undefined name).
            return res

        if type(outputJSONVal) == dict and (type(outputJSONVal[matchJSONKey]) == dict or
                type(outputJSONVal[matchJSONKey]) == list):
            if len(matchJSONVal) > 0:
                res = find_in_json(res, outputJSONVal[matchJSONKey], matchJSONVal, matchJSONKey)
            # handling corner case where matchJSONVal == [] or matchJSONVal == {}
            else:
                res = find_in_json_other(res, outputJSONVal, matchJSONVal, matchJSONKey)
        else:
            res = find_in_json(res, outputJSONVal, matchJSONVal, matchJSONKey)
    return res
363
def find_in_json_other(res, outputJSONVal, matchJSONVal, matchJSONKey=None):
    """Compare a scalar (or empty-container) expected value against the
    output value stored under matchJSONKey; a missing key is not treated
    as a failure here."""
    if matchJSONKey not in outputJSONVal:
        return res

    actual = outputJSONVal[matchJSONKey]
    if matchJSONVal == actual:
        return res

    failmsg = 'Value doesn\'t match: {}: {} != {}\nMatching against output: {}'.format(
        matchJSONKey, matchJSONVal, actual, outputJSONVal)
    res.set_result(ResultState.fail)
    res.set_failmsg(failmsg)
    return res
374
def run_one_test(pm, args, index, tidx):
    """
    Run a single test case through its full lifecycle.

    Handles the explicit 'skip' flag and the 'dependsOn' probe command,
    then runs setup, the command under test, verification (JSON or regex
    matching) and teardown, firing the plugin hooks around each stage.
    Returns a TestResult for the case.
    """
    global NAMES
    # Save the template names so they can be restored after this test;
    # they are suffixed below to make per-test resource names unique.
    ns = NAMES['NS']
    dev0 = NAMES['DEV0']
    dev1 = NAMES['DEV1']
    dummy = NAMES['DUMMY']
    result = True
    tresult = ""
    tap = ""
    res = TestResult(tidx['id'], tidx['name'])
    if args.verbose > 0:
        print("\t====================\n=====> ", end="")
    print("Test " + tidx["id"] + ": " + tidx["name"])

    if 'skip' in tidx:
        if tidx['skip'] == 'yes':
            # Case explicitly marked as skipped in the test definition.
            res = TestResult(tidx['id'], tidx['name'])
            res.set_result(ResultState.skip)
            res.set_errormsg('Test case designated as skipped.')
            pm.call_pre_case(tidx, test_skip=True)
            pm.call_post_execute(tidx)
            return res

    if 'dependsOn' in tidx:
        # Probe command: a non-zero exit means the prerequisite is not
        # available, so the test is skipped.
        if (args.verbose > 0):
            print('probe command for test skip')
        (p, procout) = exec_cmd(tidx, args, pm, 'execute', tidx['dependsOn'])
        if p:
            if (p.returncode != 0):
                res = TestResult(tidx['id'], tidx['name'])
                res.set_result(ResultState.skip)
                res.set_errormsg('probe command: test skipped.')
                pm.call_pre_case(tidx, test_skip=True)
                pm.call_post_execute(tidx)
                return res

    # populate NAMES with TESTID for this test
    NAMES['TESTID'] = tidx['id']
    NAMES['NS'] = '{}-{}'.format(NAMES['NS'], tidx['random'])
    NAMES['DEV0'] = '{}id{}'.format(NAMES['DEV0'], tidx['id'])
    NAMES['DEV1'] = '{}id{}'.format(NAMES['DEV1'], tidx['id'])
    NAMES['DUMMY'] = '{}id{}'.format(NAMES['DUMMY'], tidx['id'])

    pm.call_pre_case(tidx)
    prepare_env(tidx, args, pm, 'setup', "-----> prepare stage", tidx["setup"])

    if (args.verbose > 0):
        print('-----> execute stage')
    pm.call_pre_execute(tidx)
    (p, procout) = exec_cmd(tidx, args, pm, 'execute', tidx["cmdUnderTest"])
    if p:
        exit_code = p.returncode
    else:
        # exec_cmd returns (None, None) for an empty command line.
        exit_code = None

    pm.call_post_execute(tidx)

    if (exit_code is None or exit_code != int(tidx["expExitCode"])):
        print("exit: {!r}".format(exit_code))
        print("exit: {}".format(int(tidx["expExitCode"])))
        #print("exit: {!r} {}".format(exit_code, int(tidx["expExitCode"])))
        res.set_result(ResultState.fail)
        res.set_failmsg('Command exited with {}, expected {}\n{}'.format(exit_code, tidx["expExitCode"], procout))
        print(procout)
    else:
        if args.verbose > 0:
            print('-----> verify stage')
        (p, procout) = exec_cmd(tidx, args, pm, 'verify', tidx["verifyCmd"])
        if procout:
            if 'matchJSON' in tidx:
                # res is updated in place by verify_by_json().
                verify_by_json(procout, res, tidx, args, pm)
            elif 'matchPattern' in tidx:
                match_pattern = re.compile(
                    str(tidx["matchPattern"]), re.DOTALL | re.MULTILINE)
                match_index = re.findall(match_pattern, procout)
                if len(match_index) != int(tidx["matchCount"]):
                    res.set_result(ResultState.fail)
                    res.set_failmsg('Could not match regex pattern. Verify command output:\n{}'.format(procout))
                else:
                    res.set_result(ResultState.success)
            else:
                res.set_result(ResultState.fail)
                res.set_failmsg('Must specify a match option: matchJSON or matchPattern\n{}'.format(procout))
        elif int(tidx["matchCount"]) != 0:
            # The verify command produced nothing but matches were expected.
            res.set_result(ResultState.fail)
            res.set_failmsg('No output generated by verify command.')
        else:
            res.set_result(ResultState.success)

    prepare_env(tidx, args, pm, 'teardown', '-----> teardown stage', tidx['teardown'], procout)
    pm.call_post_case(tidx)

    index += 1

    # remove TESTID from NAMES
    del(NAMES['TESTID'])

    # Restore names
    NAMES['NS'] = ns
    NAMES['DEV0'] = dev0
    NAMES['DEV1'] = dev1
    NAMES['DUMMY'] = dummy

    return res
479
def prepare_run(pm, args, testlist):
    """
    Run every plugin's pre_suite hook.  If any hook raises, report the
    exception, close the suite down and return an emergency-exit
    message; otherwise return None.
    """
    try:
        pm.call_pre_suite(len(testlist), testlist)
    except Exception:
        ex_type, ex, ex_tb = sys.exc_info()
        print('Exception {} {} (caught in pre_suite).'.
              format(ex_type, ex))
        traceback.print_tb(ex_tb)
        pm.call_post_suite(1)
        return 'EMERGENCY EXIT, call_pre_suite failed with exception {} {}\n'.format(ex_type, ex)
498
def purge_run(pm, index):
    """Close the suite down by firing every plugin's post_suite hook."""
    pm.call_post_suite(index)
501
def test_runner(pm, args, filtered_tests):
    """
    Driver function for the unit tests.

    Prints information about the tests being run, executes the setup and
    teardown commands and the command under test itself. Also determines
    success/failure based on the information in the test case and generates
    TAP output accordingly.

    Returns (index, TestSuiteReport) where index is one past the last
    test that was attempted.
    """
    testlist = filtered_tests
    tcount = len(testlist)
    index = 1
    tap = ''
    badtest = None
    stage = None

    tsr = TestSuiteReport()

    for tidx in testlist:
        if "flower" in tidx["category"] and args.device == None:
            # flower tests need a physical NIC (-d); skip them otherwise.
            errmsg = "Tests using the DEV2 variable must define the name of a "
            errmsg += "physical NIC with the -d option when running tdc.\n"
            errmsg += "Test has been skipped."
            if args.verbose > 1:
                print(errmsg)
            res = TestResult(tidx['id'], tidx['name'])
            res.set_result(ResultState.skip)
            res.set_errormsg(errmsg)
            tsr.add_resultdata(res)
            index += 1
            continue
        try:
            badtest = tidx  # in case it goes bad
            res = run_one_test(pm, args, index, tidx)
            tsr.add_resultdata(res)
        except PluginMgrTestFail as pmtf:
            # A stage failed hard: record the failure, dump diagnostics
            # and stop running further tests.
            ex_type, ex, ex_tb = sys.exc_info()
            stage = pmtf.stage
            message = pmtf.message
            output = pmtf.output
            res = TestResult(tidx['id'], tidx['name'])
            res.set_result(ResultState.fail)
            res.set_errormsg(pmtf.message)
            res.set_failmsg(pmtf.output)
            tsr.add_resultdata(res)
            index += 1
            print(message)
            print('Exception {} {} (caught in test_runner, running test {} {} {} stage {})'.
                  format(ex_type, ex, index, tidx['id'], tidx['name'], stage))
            print('---------------')
            print('traceback')
            traceback.print_tb(ex_tb)
            print('---------------')
            if stage == 'teardown':
                print('accumulated output for this test:')
                if pmtf.output:
                    print(pmtf.output)
            print('---------------')
            break
        index += 1

    # if we failed in setup or teardown,
    # fill in the remaining tests with ok-skipped
    count = index

    if tcount + 1 != count:
        for tidx in testlist[count - 1:]:
            res = TestResult(tidx['id'], tidx['name'])
            res.set_result(ResultState.skip)
            msg = 'skipped - previous {} failed {} {}'.format(stage,
                index, badtest.get('id', '--Unknown--'))
            res.set_errormsg(msg)
            tsr.add_resultdata(res)
            count += 1

    if args.pause:
        print('Want to pause\nPress enter to continue ...')
        if input(sys.stdin):
            print('got something on stdin')

    return (index, tsr)
583
def mp_bins(alltests):
    """
    Split alltests into (serial, parallel) bins.

    Tests without the nsPlugin run serially, as do netdevsim tests
    (only one netdevsim device can be created at a time); everything
    else is eligible for parallel execution.
    """
    serial = []
    parallel = []

    for case in alltests:
        runs_serial = ('nsPlugin' not in case['plugins'] or
                       'netdevsim/new_device' in str(case['setup']))
        (serial if runs_serial else parallel).append(case)

    return (serial, parallel)
599
def __mp_runner(tests):
    # Pool worker entry point: reads the module globals mp_pm/mp_args
    # because the plugin manager and parsed args cannot be pickled across
    # processes (set up in test_runner_mp).  Returns the raw per-test
    # result list so the parent can merge it into one TestSuiteReport.
    (_, tsr) = test_runner(mp_pm, mp_args, tests)
    return tsr._testsuite
603
def test_runner_mp(pm, args, alltests):
    """
    Run the suite with args.mp worker processes.

    Serial-only tests form the first batch; the remaining tests are
    chunked 32 at a time and handed to a multiprocessing Pool.  All
    worker results are merged into a single TestSuiteReport.
    """
    prepare_run(pm, args, alltests)

    (serial, parallel) = mp_bins(alltests)

    chunk = 32
    batches = [parallel[i:i + chunk] for i in range(0, len(parallel), chunk)]
    batches.insert(0, serial)

    print("Executing {} tests in parallel and {} in serial".format(len(parallel), len(serial)))
    print("Using {} batches and {} workers".format(len(batches), args.mp))

    # The plugin manager and parsed args can't be pickled, so expose
    # them to the workers as module globals instead of Pool arguments.
    global mp_pm
    mp_pm = pm

    global mp_args
    mp_args = args

    with Pool(args.mp) as pool:
        worker_suites = pool.map(__mp_runner, batches)

    tsr = TestSuiteReport()
    for suite in worker_suites:
        for res in suite:
            tsr.add_resultdata(res)

    # Passing an index is not useful in MP
    purge_run(pm, None)

    return tsr
634
def test_runner_serial(pm, args, alltests):
    """Run every test in-process, in order, and return the suite report."""
    prepare_run(pm, args, alltests)

    if args.verbose:
        print("Executing {} tests in serial".format(len(alltests)))

    (index, tsr) = test_runner(pm, args, alltests)
    purge_run(pm, index)

    return tsr
646
def has_blank_ids(idlist):
    """Return True when any ID in idlist is empty (falsy)."""
    return not all(idlist)
652
653
def load_from_file(filename):
    """
    Read one JSON test-case file and return its cases as a list of
    OrderedDicts.  An unparseable file is reported and yields [].
    """
    try:
        with open(filename) as test_data:
            testlist = json.load(test_data, object_pairs_hook=OrderedDict)
    except json.JSONDecodeError as jde:
        print('IGNORING test case file {}\n\tBECAUSE:  {}'.format(filename, jde))
        return list()

    # Tag each case with its origin file when any ID is blank, so that
    # generate_case_ids() can write the updated cases back to disk.
    if has_blank_ids(get_id_list(testlist)):
        for case in testlist:
            case['filename'] = filename
    return testlist
671
def identity(string):
    """Return the argument unchanged; registered as argparse's default
    'type' callable in args_parse()."""
    return string
674
def args_parse():
    """Create tdc's top-level ArgumentParser."""
    new_parser = argparse.ArgumentParser(description='Linux TC unit tests')
    # Make the no-op 'identity' callable the default type converter.
    new_parser.register('type', None, identity)
    return new_parser
682
683
def set_args(parser):
    """
    Set the command line arguments for tdc.

    Arguments are split into a 'selection' group (which tests to run)
    and an 'action' group (what to do with them), plus global options.
    Returns the parser with all tdc options registered.
    """
    parser.add_argument(
        '--outfile', type=str,
        help='Path to the file in which results should be saved. ' +
        'Default target is the current directory.')
    parser.add_argument(
        '-p', '--path', type=str,
        help='The full path to the tc executable to use')
    # Argument groups: selection (which tests) and action (what to do).
    sg = parser.add_argument_group(
        'selection', 'select which test cases: ' +
        'files plus directories; filtered by categories plus testids')
    ag = parser.add_argument_group(
        'action', 'select action to perform on selected test cases')

    sg.add_argument(
        '-D', '--directory', nargs='+', metavar='DIR',
        help='Collect tests from the specified directory(ies) ' +
        '(default [tc-tests])')
    sg.add_argument(
        '-f', '--file', nargs='+', metavar='FILE',
        help='Run tests from the specified file(s)')
    sg.add_argument(
        '-c', '--category', nargs='*', metavar='CATG', default=['+c'],
        help='Run tests only from the specified category/ies, ' +
        'or if no category/ies is/are specified, list known categories.')
    sg.add_argument(
        '-e', '--execute', nargs='+', metavar='ID',
        help='Execute the specified test cases with specified IDs')
    ag.add_argument(
        '-l', '--list', action='store_true',
        help='List all test cases, or those only within the specified category')
    ag.add_argument(
        '-s', '--show', action='store_true', dest='showID',
        help='Display the selected test cases')
    ag.add_argument(
        '-i', '--id', action='store_true', dest='gen_id',
        help='Generate ID numbers for new test cases')
    parser.add_argument(
        '-v', '--verbose', action='count', default=0,
        help='Show the commands that are being run')
    parser.add_argument(
        '--format', default='tap', const='tap', nargs='?',
        choices=['none', 'xunit', 'tap'],
        help='Specify the format for test results. (Default: TAP)')
    parser.add_argument('-d', '--device',
                        help='Execute test cases that use a physical device, ' +
                        'where DEVICE is its name. (If not defined, tests ' +
                        'that require a physical device will be skipped)')
    parser.add_argument(
        '-P', '--pause', action='store_true',
        help='Pause execution just before post-suite stage')
    parser.add_argument(
        '-J', '--multiprocess', type=int, default=1, dest='mp',
        help='Run tests in parallel whenever possible')
    return parser
742
743
def check_default_settings(args, remaining, pm):
    """
    Apply command-line overrides to the NAMES table, sanity-check the
    resulting executable paths, then let plugins validate their own
    arguments.
    """
    # Allow for overriding specific settings
    global NAMES

    if args.path is not None:
        NAMES['TC'] = args.path
    if args.device is not None:
        NAMES['DEV2'] = args.device
    NAMES.setdefault('TIMEOUT', None)

    if 'ETHTOOL' in NAMES and not os.path.isfile(NAMES['ETHTOOL']):
        print(f"The specified ethtool path {NAMES['ETHTOOL']} does not exist.")
        exit(1)
    if not os.path.isfile(NAMES['TC']):
        print("The specified tc path " + NAMES['TC'] + " does not exist.")
        exit(1)

    pm.call_check_args(args, remaining)
766
767
def get_id_list(alltests):
    """Collect the 'id' field of every test case, in order."""
    return [case["id"] for case in alltests]
773
def check_case_id(alltests):
    """
    Check for duplicate test case IDs.

    Returns every occurrence of an ID that appears more than once (so
    duplicated IDs are themselves repeated in the result, matching the
    previous behaviour).
    """
    idl = get_id_list(alltests)
    # Count each ID once up front instead of calling list.count() for
    # every element, which made this quadratic in the number of tests.
    occurrences = {}
    for tid in idl:
        occurrences[tid] = occurrences.get(tid, 0) + 1
    return [tid for tid in idl if occurrences[tid] > 1]
780
781
def does_id_exist(alltests, newid):
    """Return True when newid collides with an existing test case ID."""
    return newid in get_id_list(alltests)
788
789
def generate_case_ids(alltests):
    """
    If a test case has a blank ID field, generate a random hex ID for it
    and then write the test cases back to disk.
    """
    for c in alltests:
        if c["id"] == "":
            # Draw 4-hex-digit IDs until one doesn't collide.
            while True:
                newid = '{:04x}'.format(random.randrange(16**4))
                if not does_id_exist(alltests, newid):
                    c['id'] = newid
                    break

    # Collect the distinct source files that contributed cases.
    ufilename = []
    for c in alltests:
        if 'filename' in c:
            ufilename.append(c['filename'])
    ufilename = get_unique_item(ufilename)
    for f in ufilename:
        testlist = []
        for t in alltests:
            if t.get('filename') == f:
                # Strip the bookkeeping key before writing back.
                del t['filename']
                testlist.append(t)
        # Bug fix: use a context manager so the file is closed even if
        # json.dump() raises.
        with open(f, "w") as outfile:
            json.dump(testlist, outfile, indent=4)
            outfile.write("\n")
821
def filter_tests_by_id(args, testlist):
    '''
    Keep only the test cases whose IDs were requested via -e/--execute.
    Returns an empty list when no IDs were given.
    '''
    if not (testlist and args.execute):
        return list()

    wanted = args.execute
    if not (isinstance(wanted, list) and len(wanted) > 0):
        return list()

    return [tc for tc in testlist if tc['id'] in wanted]
834
def filter_tests_by_category(args, testlist):
    '''
    Keep only the test cases belonging to the requested categories,
    without duplicating a case that matches several of them.
    '''
    answer = list()
    if not (args.category and testlist):
        return answer

    picked_ids = list()
    for catg in set(args.category):
        # '+c' is the placeholder default, not a real category.
        if catg == '+c':
            continue
        print('considering category {}'.format(catg))
        for tc in testlist:
            if catg in tc['category'] and tc['id'] not in picked_ids:
                answer.append(tc)
                picked_ids.append(tc['id'])

    return answer
852
def set_random(alltests):
    """Attach a fresh 32-bit random value to every test case in place."""
    for testcase in alltests:
        testcase['random'] = random.getrandbits(32)
856
def get_test_cases(args):
    """
    If a test case file is specified, retrieve tests from that file.
    Otherwise, glob for all json files in subdirectories and load from
    each one.
    Also, if requested, filter by category, and add tests matching
    certain ids.
    """
    import fnmatch

    flist = []
    testdirs = ['tc-tests']

    if args.file:
        # Explicit files were given; drop the default search directory.
        testdirs = []
        for ff in args.file:
            if os.path.isfile(ff):
                flist.append(os.path.abspath(ff))
            else:
                print("IGNORING file " + ff + "\n\tBECAUSE does not exist.")

    if args.directory:
        testdirs = args.directory

    # Walk each test directory, collecting every *.json file.
    for testdir in testdirs:
        for root, _, filenames in os.walk(testdir):
            for fname in fnmatch.filter(filenames, '*.json'):
                candidate = os.path.abspath(os.path.join(root, fname))
                if candidate not in testdirs:
                    flist.append(candidate)

    alltestcases = list()
    for casefile in flist:
        alltestcases += load_from_file(casefile)

    allcatlist = get_test_categories(alltestcases)
    allidlist = get_id_list(alltestcases)

    testcases_by_cats = get_categorized_testlist(alltestcases, allcatlist)
    idtestcases = filter_tests_by_id(args, alltestcases)
    cattestcases = filter_tests_by_category(args, alltestcases)

    cat_ids = [tc['id'] for tc in cattestcases]
    if args.execute:
        if args.category:
            # Union of category matches and id matches, without
            # duplicating tests already selected by category.
            extras = [tc for tc in idtestcases if tc['id'] not in cat_ids]
            alltestcases = cattestcases + extras
        else:
            alltestcases = idtestcases
    elif cat_ids:
        alltestcases = cattestcases
    # otherwise keep alltestcases as loaded, already limited by
    # the file/directory selection above

    return allcatlist, allidlist, testcases_by_cats, alltestcases
916
917
def set_operation_mode(pm, parser, args, remaining):
    """
    Load the test case data and process remaining arguments to determine
    what the script should do for this run, and call the appropriate
    function.

    Exits the process with a kselftest-style code: 0 pass, 1 fail,
    4 skip (no tests found).
    """
    ucat, idlist, testcases, alltests = get_test_cases(args)

    if args.gen_id:
        if has_blank_ids(idlist):
            alltests = generate_case_ids(alltests)
        else:
            print("No empty ID fields found in test files.")
        exit(0)

    # Refuse to run with duplicate test case ids.
    duplicate_ids = check_case_id(alltests)
    if len(duplicate_ids) > 0:
        print("The following test case IDs are not unique:")
        print(str(set(duplicate_ids)))
        print("Please correct them before continuing.")
        exit(1)

    if args.showID:
        for atest in alltests:
            print_test_case(atest)
        exit(0)

    # '-c' with no category argument lists the available categories.
    if isinstance(args.category, list) and (len(args.category) == 0):
        print("Available categories:")
        print_sll(ucat)
        exit(0)

    if args.list:
        list_test_cases(alltests)
        exit(0)

    set_random(alltests)

    exit_code = 0 # KSFT_PASS
    if len(alltests):
        req_plugins = pm.get_required_plugins(alltests)
        try:
            args = pm.load_required_plugins(req_plugins, parser, args, remaining)
        except PluginDependencyException as pde:
            # Best effort: report missing plugins but run anyway.
            print('The following plugins were not found:')
            print('{}'.format(pde.missing_pg))

        if args.mp > 1:
            catresults = test_runner_mp(pm, args, alltests)
        else:
            catresults = test_runner_serial(pm, args, alltests)

        if catresults.count_failures() != 0:
            exit_code = 1 # KSFT_FAIL
        if args.format == 'none':
            print('Test results output suppression requested\n')
        else:
            print('\nAll test results: \n')
            if args.format == 'xunit':
                suffix = 'xml'
                res = catresults.format_xunit()
            elif args.format == 'tap':
                suffix = 'tap'
                res = catresults.format_tap()
            print(res)
            print('\n\n')
            if not args.outfile:
                fname = 'test-results.{}'.format(suffix)
            else:
                fname = args.outfile
            # The with-statement closes fh on exit; the original also
            # called fh.close() explicitly, closing the file twice.
            with open(fname, 'w') as fh:
                fh.write(res)
            # When run under sudo, hand the results file back to the
            # invoking user so it is not left root-owned.
            if os.getenv('SUDO_UID') is not None:
                os.chown(fname, uid=int(os.getenv('SUDO_UID')),
                    gid=int(os.getenv('SUDO_GID')))
    else:
        print('No tests found\n')
        exit_code = 4 # KSFT_SKIP
    exit(exit_code)
998
def main():
    """
    Start of execution; set up argument parser and get the arguments,
    and start operations.
    """
    import resource

    # Tuple comparison handles any future major release correctly; the
    # old check (major < 3 or minor < 8) would wrongly reject e.g. a
    # hypothetical Python 4.0.
    if sys.version_info < (3, 8):
        sys.exit("tdc requires at least python 3.8")

    # Raise the open-file limit; parallel runs can hold many fds open.
    resource.setrlimit(resource.RLIMIT_NOFILE, (1048576, 1048576))

    parser = args_parse()
    parser = set_args(parser)
    pm = PluginMgr(parser)
    parser = pm.call_add_args(parser)
    (args, remaining) = parser.parse_known_args()
    args.NAMES = NAMES
    # Cap the worker-process count at 4.
    args.mp = min(args.mp, 4)
    pm.set_args(args)
    check_default_settings(args, remaining, pm)
    if args.verbose > 2:
        print('args is {}'.format(args))

    try:
        set_operation_mode(pm, parser, args, remaining)
    except KeyboardInterrupt:
        # Cleanup on Ctrl-C
        pm.call_post_suite(None)
1028
# Standard entry-point guard: run the driver only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
1031