xref: /freebsd/tests/atf_python/atf_pytest.py (revision 942815c54820783d3d4f7f6faa71ab7919b5f0e5)
1import types
2from typing import Any
3from typing import Dict
4from typing import List
5from typing import NamedTuple
6from typing import Optional
7from typing import Tuple
9import pytest
10import os
def nodeid_to_method_name(nodeid: str) -> str:
    """Extract the bare test method name from a pytest node id.

    Keeps only the text after the last ``::`` separator and drops any
    ``[...]`` parametrization suffix, e.g.
    ``file_name.py::ClassName::method_name[parametrize]`` -> ``method_name``.
    """
    last_component = nodeid.rsplit("::", 1)[-1]
    method_name, _sep, _params = last_component.partition("[")
    return method_name
class ATFCleanupItem(pytest.Item):
    """Pytest item behaviour used when ATF asks for a test's cleanup phase.

    ``ATFHandler.override_runtest`` grafts ``runtest`` onto a collected item
    and installs the no-op setup/teardown methods on the item's test class.
    """

    def runtest(self):
        """Invoke the test's cleanup hook instead of the test body.

        A method-specific ``cleanup_<method_name>`` hook is preferred. If it
        is absent, the class-wide ``cleanup`` hook is used. Whichever hook is
        found receives the item's full nodeid. If neither exists, nothing
        is called.
        """
        test_instance = self.parent.cls()
        nodeid = self.nodeid
        specific_hook = "cleanup_{}".format(nodeid_to_method_name(nodeid))
        for hook_name in (specific_hook, "cleanup"):
            if hasattr(test_instance, hook_name):
                getattr(test_instance, hook_name)(nodeid)
                break

    def setup_method_noop(self, method):
        """Stand-in for the test class's setup_method; deliberately does nothing."""

    def teardown_method_noop(self, method):
        """Stand-in for the test class's teardown_method; deliberately does nothing."""
class ATFTestObj(object):
    """ATF test-case metadata derived from a collected pytest item.

    ``as_lines`` renders the item as the ``key: value`` lines of one entry in
    an ATF test-program listing.
    """

    def __init__(self, obj, has_cleanup):
        # Use nodeid without the file name to properly name class-derived
        # tests ("ClassName::method_name[params]").
        self.ident = obj.nodeid.split("::", 1)[1]
        self.description = self._get_test_description(obj)
        self.has_cleanup = has_cleanup
        self.obj = obj

    def _get_test_description(self, obj):
        """Return the first non-blank docstring line (stripped) or the item name.

        Docstring lines are stripped so that indented multi-line docstrings
        do not leak their indentation (or whitespace-only lines) into the
        ``descr:`` output.
        """
        docstr = obj.function.__doc__
        if docstr:
            for line in docstr.split("\n"):
                line = line.strip()
                if line:
                    return line
        return obj.name

    @staticmethod
    def _convert_user_mark(mark, obj, ret: Dict):
        """Translate a ``require_user`` mark into ATF properties stored in ``ret``.

        The special user "unprivileged" also appends "unprivileged_user" to
        ``require.config``. If the test class sets a truthy ``NEED_ROOT``,
        ``require.user`` is forced to "root". Otherwise it is the requested
        user name.
        """
        username = mark.args[0]
        if username == "unprivileged":
            # Special unprivileged user requested.
            # First, require the unprivileged-user config option presence
            key = "require.config"
            if key not in ret:
                ret[key] = "unprivileged_user"
            else:
                ret[key] = "{} {}".format(ret[key], "unprivileged_user")
        # Check if the framework requires root
        test_cls = ATFHandler.get_test_class(obj)
        if test_cls and getattr(test_cls, "NEED_ROOT", False):
            # Yes, so we ask kyua to run us under root instead.
            # It is up to the implementation to switch back to the desired
            # user.
            ret["require.user"] = "root"
        else:
            ret["require.user"] = username

    def _convert_marks(self, obj) -> Dict[str, Any]:
        """Map the item's pytest marks onto ATF metadata properties.

        Marks not listed in the mapping below are ignored. ``iter_markers``
        yields marks from the closest node (the test function) outward to
        class and module. Only the first occurrence of each mark name is
        used, so a function-level mark overrides a class- or module-level
        one. This matches ``ATFHandler.setup_method_pre``, which also acts
        on the first ``require_user`` mark it sees.

        Returns:
            Ordered dict of ATF property name -> value.
        """

        def join_words(value):
            # Accept a list of words or an already-joined single string;
            # " ".join() on a bare str would space out its characters.
            if isinstance(value, str):
                return value
            return " ".join(value)

        _map: Dict[str, Dict] = {
            "require_user": {"handler": self._convert_user_mark},
            "require_arch": {"name": "require.arch", "fmt": join_words},
            "require_diskspace": {"name": "require.diskspace"},
            "require_files": {"name": "require.files", "fmt": join_words},
            "require_machine": {"name": "require.machine", "fmt": join_words},
            "require_memory": {"name": "require.memory"},
            "require_progs": {"name": "require.progs", "fmt": join_words},
            "timeout": {},
        }
        ret: Dict[str, Any] = {}
        seen = set()
        for mark in obj.iter_markers():
            spec = _map.get(mark.name)
            if spec is None or mark.name in seen:
                # Unknown mark, or already set by a closer node.
                continue
            seen.add(mark.name)
            if "handler" in spec:
                spec["handler"](mark, obj, ret)
                continue
            name = spec.get("name", mark.name)
            fmt = spec.get("fmt")
            ret[name] = fmt(mark.args[0]) if fmt else mark.args[0]
        return ret

    def as_lines(self) -> List[str]:
        """Output test definition in ATF-specific format"""
        ret = []
        ret.append("ident: {}".format(self.ident))
        ret.append("descr: {}".format(self.description))
        if self.has_cleanup:
            ret.append("has.cleanup: true")
        for key, value in self._convert_marks(self.obj).items():
            ret.append("{}: {}".format(key, value))
        return ret
class ATFHandler(object):
    """Bridges pytest collection and reporting to the ATF test-program protocol.

    Lists collected tests in ATF format and condenses pytest's per-stage
    reports into a single ATF result line written to an optional report file.
    """

    class ReportState(NamedTuple):
        # ATF result state, e.g. "passed", "failed", "skipped",
        # "expected_failure".
        state: str
        # One-line reason extracted from the pytest report; may be None.
        reason: Optional[str]

    def __init__(self, report_file_name: Optional[str]):
        # Per-test ATF state keyed by test name (report.location[2]).
        self._tests_state_map: Dict[str, "ATFHandler.ReportState"] = {}
        self._report_file_name = report_file_name
        self._report_file_handle = None

    def setup_configure(self):
        """Open the ATF result file for writing, if a file name was given."""
        fname = self._report_file_name
        if fname:
            self._report_file_handle = open(fname, mode="w")

    def setup_method_pre(self, item):
        """Called before actually running the test setup_method.

        Stores the user from the first ``require_user`` mark as
        ``TARGET_USER`` on the test class, so the class can drop privileges
        itself.
        """
        # Check if we need to manually drop the privileges
        for mark in item.iter_markers():
            if mark.name == "require_user":
                cls = self.get_test_class(item)
                cls.TARGET_USER = mark.args[0]
                break

    def override_runtest(self, obj):
        """Make item ``obj`` run its cleanup hook instead of its test body.

        The no-op setup/teardown methods are assigned on the test *class*, so
        they apply to every instance of that class in this process.
        """
        # Override basic runtest command
        obj.runtest = types.MethodType(ATFCleanupItem.runtest, obj)
        # Override class setup/teardown
        obj.parent.cls.setup_method = ATFCleanupItem.setup_method_noop
        obj.parent.cls.teardown_method = ATFCleanupItem.teardown_method_noop

    @staticmethod
    def get_test_class(obj):
        """Return the class owning test item ``obj``, or None if there is none."""
        if hasattr(obj, "parent") and obj.parent is not None:
            if hasattr(obj.parent, "cls"):
                return obj.parent.cls
        return None

    def has_object_cleanup(self, obj):
        """Return True if the item's class defines ``cleanup`` or ``cleanup_<method>``."""
        cls = self.get_test_class(obj)
        if cls is not None:
            method_name = nodeid_to_method_name(obj.nodeid)
            cleanup_name = "cleanup_{}".format(method_name)
            if hasattr(cls, "cleanup") or hasattr(cls, cleanup_name):
                return True
        return False

    def list_tests(self, tests: List["pytest.Item"]):
        """Print the ATF test-program listing for the collected test items."""
        print('Content-Type: application/X-atf-tp; version="1"')
        print()
        for test_obj in tests:
            has_cleanup = self.has_object_cleanup(test_obj)
            atf_test = ATFTestObj(test_obj, has_cleanup)
            for line in atf_test.as_lines():
                print(line)
            print()

    def set_report_state(self, test_name: str, state: str, reason: Optional[str]):
        """Record (overwriting) the ATF state for ``test_name``."""
        self._tests_state_map[test_name] = self.ReportState(state, reason)

    def _extract_report_reason(self, report):
        """Return a one-line reason taken from ``report.longrepr``, or None.

        Skip reports carry a ``(path, lineno, message)`` tuple. The message
        is returned with a leading "Skipped: " prefix removed. Any other
        representation is converted to text, and its last line (ignoring
        trailing newlines) is returned.
        """
        data = report.longrepr
        if data is None:
            return None
        if isinstance(data, tuple):
            # ('/path/to/test.py', 23, 'Skipped: unable to test')
            reason = data[2]
            prefix = "Skipped: "
            if reason.startswith(prefix):
                reason = reason[len(prefix):]
            return reason
        # string / traceback / exception report. Capture the last line.
        return str(data).rstrip("\n").split("\n")[-1]

    def add_report(self, report):
        """Fold one pytest stage report into the per-test ATF state."""
        # MAP pytest report state to the atf-desired state
        #
        # ATF test states:
        # (1) expected_death, (2) expected_exit, (3) expected_failure
        # (4) expected_signal, (5) expected_timeout, (6) passed
        # (7) skipped, (8) failed
        #
        # Note that ATF doesn't have the concept of "soft xfail" - xpass
        # is a failure. It also calls teardown routine in a separate
        # process, thus teardown states (pytest-only) are handled as
        # body continuation.

        # (stage, state, wasxfail)

        # Just a passing test: WANT: passed
        # GOT: (setup, passed, F), (call, passed, F), (teardown, passed, F)
        #
        # Failing body test: WANT: failed
        # GOT: (setup, passed, F), (call, failed, F), (teardown, passed, F)
        #
        # pytest.skip test decorator: WANT: skipped
        # GOT: (setup,skipped, False), (teardown, passed, False)
        #
        # pytest.skip call inside test function: WANT: skipped
        # GOT: (setup, passed, F), (call, skipped, F), (teardown,passed, F)
        #
        # mark.xfail decorator+pytest.xfail: WANT: expected_failure
        # GOT: (setup, passed, F), (call, skipped, T), (teardown, passed, F)
        #
        # mark.xfail decorator+pass: WANT: failed
        # GOT: (setup, passed, F), (call, passed, T), (teardown, passed, F)

        test_name = report.location[2]
        stage = report.when
        state = report.outcome
        reason = self._extract_report_reason(report)
        # pytest attaches ``wasxfail`` to the report itself (not to the
        # extracted reason string).
        # We don't care about strict xfail - it gets translated to False.
        wasxfail = hasattr(report, "wasxfail")

        if stage == "setup":
            if state in ("skipped", "failed"):
                # failed init -> failed test, skipped setup -> xskip
                # for the whole test
                self.set_report_state(test_name, state, reason)
        elif stage == "call":
            # "call" stage shouldn't matter if setup failed
            current = self._tests_state_map.get(test_name)
            if current is not None and current.state == "failed":
                return
            if state == "failed":
                # Record failure & override "skipped" state
                self.set_report_state(test_name, state, reason)
            elif state == "skipped":
                if wasxfail:
                    # xfail() called in the test body / expected failure
                    state = "expected_failure"
                # otherwise: skip inside the body
                self.set_report_state(test_name, state, reason)
            elif state == "passed":
                if wasxfail:
                    # the test was expected to fail but didn't:
                    # mark as hard failure
                    state = "failed"
                self.set_report_state(test_name, state, reason)
        elif stage == "teardown":
            if state == "failed":
                # teardown should be empty, as the cleanup
                # procedures should be implemented as a separate
                # function/method, so mark teardown failure as
                # global failure
                self.set_report_state(test_name, state, reason)

    def write_report(self):
        """Write the ATF result line and close the report file.

        Does nothing if no report file is open. The handle is released
        before writing, so repeated calls are no-ops and the file is closed
        even if writing fails.
        """
        handle = self._report_file_handle
        if handle is None:
            return
        self._report_file_handle = None
        try:
            if self._tests_state_map:
                # If we're executing in ATF mode, there has to be just one
                # test. Anyway, deterministically pick the first one.
                first_test_name = next(iter(self._tests_state_map))
                test = self._tests_state_map[first_test_name]
                if test.state == "passed":
                    line = test.state
                else:
                    line = "{}: {}".format(test.state, test.reason)
                print(line, file=handle)
        finally:
            handle.close()

    @staticmethod
    def get_atf_vars() -> Dict[str, str]:
        """Return environment variables prefixed ``_ATF_VAR_``, prefix removed."""
        px = "_ATF_VAR_"
        return {k[len(px):]: v for k, v in os.environ.items() if k.startswith(px)}