xref: /freebsd/tests/atf_python/atf_pytest.py (revision d8e36cd2b10f78470c1de56337f685c10ce26ed2)
1import types
2from typing import Any
3from typing import Dict
4from typing import List
5from typing import NamedTuple
6from typing import Optional
7from typing import Tuple
8
9from atf_python.utils import nodeid_to_method_name
10
11import pytest
12import os
13
14
class ATFCleanupItem(pytest.Item):
    """Source of replacement callables for running a test's cleanup.

    ``ATFHandler.override_runtest`` binds ``runtest`` onto a collected test
    item and installs the two ``*_noop`` functions as the test class's
    ``setup_method``/``teardown_method``.
    """

    def runtest(self):
        """Invoke the test's cleanup routine in place of the test body.

        A per-test ``cleanup_<method name>`` method on a fresh instance of
        the test class is preferred; failing that, a class-wide ``cleanup``
        method is used. Whichever is found is called with the item's nodeid.
        If neither exists, nothing happens.
        """
        test_instance = self.parent.cls()
        specific_name = "cleanup_{}".format(nodeid_to_method_name(self.nodeid))
        # Most specific routine first; stop at the first one that exists.
        for attr_name in (specific_name, "cleanup"):
            if hasattr(test_instance, attr_name):
                getattr(test_instance, attr_name)(self.nodeid)
                break

    def setup_method_noop(self, method):
        """Stand-in for the test class's ``setup_method``; does nothing."""
        return None

    def teardown_method_noop(self, method):
        """Stand-in for the test class's ``teardown_method``; does nothing."""
        return None
33
34
class ATFTestObj(object):
    """ATF view of a single collected pytest test item.

    ``as_lines`` renders the item as one test-case entry of the
    ``application/X-atf-tp`` listing, translating the supported pytest marks
    into ATF metadata properties.
    """

    def __init__(self, obj, has_cleanup):
        # Use nodeid without name to properly name class-derived tests
        # (everything after the first "::", e.g. "TestCls::test_y").
        self.ident = obj.nodeid.split("::", 1)[1]
        self.description = self._get_test_description(obj)
        self.has_cleanup = has_cleanup
        self.obj = obj

    def _get_test_description(self, obj):
        """Return the first non-blank docstring line, or the test name.

        Lines are stripped so the indentation of a method docstring, or a
        whitespace-only line, does not end up in the ATF ``descr`` value.
        """
        docstr = obj.function.__doc__
        if docstr:
            for line in docstr.splitlines():
                line = line.strip()
                if line:
                    return line
        return obj.name

    @staticmethod
    def _convert_user_mark(mark, obj, ret: Dict):
        """Translate a ``require_user`` mark into ATF properties in *ret*.

        ``mark.args[0]`` is the requested user name. ``"unprivileged"``
        additionally appends ``unprivileged_user`` to ``require.config``.
        If the test class sets ``NEED_ROOT``, ``require.user`` is forced to
        ``root`` and switching to the desired user is left to the test.
        """
        username = mark.args[0]
        if username == "unprivileged":
            # Special unprivileged user requested.
            # First, require the unprivileged-user config option presence
            key = "require.config"
            if key not in ret:
                ret[key] = "unprivileged_user"
            else:
                ret[key] = "{} {}".format(ret[key], "unprivileged_user")
        # Check if the framework requires root
        test_cls = ATFHandler.get_test_class(obj)
        if test_cls and getattr(test_cls, "NEED_ROOT", False):
            # Yes, so we ask kyua to run us under root instead
            # It is up to the implementation to switch back to the desired
            # user
            ret["require.user"] = "root"
        else:
            ret["require.user"] = username

    def _convert_marks(self, obj) -> Dict[str, Any]:
        """Translate supported pytest marks on *obj* into ATF properties.

        ``obj.iter_markers()`` yields marks from the closest level (the test
        function) to the farthest (class, module). Only the first occurrence
        of each mark name is used, so a mark on the test itself overrides the
        same mark on its class or module. ``ATFHandler.setup_method_pre``
        likewise uses the first ``require_user`` mark it sees.
        """
        wj_func = lambda x: " ".join(x)  # noqa: E731
        _map: Dict[str, Dict] = {
            "require_user": {"handler": self._convert_user_mark},
            "require_arch": {"name": "require.arch", "fmt": wj_func},
            "require_diskspace": {"name": "require.diskspace"},
            "require_files": {"name": "require.files", "fmt": wj_func},
            "require_machine": {"name": "require.machine", "fmt": wj_func},
            "require_memory": {"name": "require.memory"},
            "require_progs": {"name": "require.progs", "fmt": wj_func},
            "timeout": {},
        }
        ret: Dict[str, Any] = {}
        seen = set()
        for mark in obj.iter_markers():
            spec = _map.get(mark.name)
            # Unsupported marks are ignored; farther duplicates must not
            # override (or, for require_user, re-append to) closer ones.
            if spec is None or mark.name in seen:
                continue
            seen.add(mark.name)
            if "handler" in spec:
                spec["handler"](mark, obj, ret)
                continue
            name = spec.get("name", mark.name)
            fmt = spec.get("fmt")
            ret[name] = fmt(mark.args[0]) if fmt is not None else mark.args[0]
        return ret

    def as_lines(self) -> List[str]:
        """Output test definition in ATF-specific format"""
        ret = [
            "ident: {}".format(self.ident),
            "descr: {}".format(self.description),
        ]
        if self.has_cleanup:
            ret.append("has.cleanup: true")
        for key, value in self._convert_marks(self.obj).items():
            ret.append("{}: {}".format(key, value))
        return ret
108        return ret
109
110
class ATFHandler(object):
    """Bridges pytest hooks and the ATF test-program interface.

    Lists tests in the ATF ``application/X-atf-tp`` format, swaps test bodies
    for their cleanup routines, and collapses pytest's per-stage reports
    into the single ATF result line written to the report file.
    """

    class ReportState(NamedTuple):
        # ATF result state: "passed", "failed", "skipped" or
        # "expected_failure".
        state: str
        # Free-form reason; None when pytest supplied nothing usable.
        reason: Optional[str]

    def __init__(self, report_file_name: Optional[str]):
        # Test name (report.location[2]) -> latest ATF state for that test.
        self._tests_state_map: Dict[str, "ATFHandler.ReportState"] = {}
        self._report_file_name = report_file_name
        self._report_file_handle = None

    def setup_configure(self):
        """Open the ATF result file for writing, if a name was supplied."""
        fname = self._report_file_name
        if fname:
            self._report_file_handle = open(fname, mode="w")

    def setup_method_pre(self, item):
        """Called before actually running the test setup_method.

        Stores the user from the first (closest) ``require_user`` mark as
        ``TARGET_USER`` on the test class, so that the class can drop
        privileges itself.
        """
        # Check if we need to manually drop the privileges
        for mark in item.iter_markers():
            if mark.name == "require_user":
                cls = self.get_test_class(item)
                # Module-level test functions have no class to annotate.
                if cls is not None:
                    cls.TARGET_USER = mark.args[0]
                break

    def override_runtest(self, obj):
        """Make test item *obj* run its cleanup routine instead of its body."""
        # Override basic runtest command
        obj.runtest = types.MethodType(ATFCleanupItem.runtest, obj)
        # Override class setup/teardown
        obj.parent.cls.setup_method = ATFCleanupItem.setup_method_noop
        obj.parent.cls.teardown_method = ATFCleanupItem.teardown_method_noop

    @staticmethod
    def get_test_class(obj):
        """Return ``obj.parent.cls`` when available, otherwise None."""
        parent = getattr(obj, "parent", None)
        if parent is None:
            return None
        return getattr(parent, "cls", None)

    def has_object_cleanup(self, obj):
        """Return True if the test's class defines a cleanup routine.

        Either a class-wide ``cleanup`` or a per-test
        ``cleanup_<method name>`` attribute counts.
        """
        cls = self.get_test_class(obj)
        if cls is None:
            return False
        cleanup_name = "cleanup_{}".format(nodeid_to_method_name(obj.nodeid))
        return hasattr(cls, "cleanup") or hasattr(cls, cleanup_name)

    def _generate_test_cleanups(self, items):
        """Keep only tests that have cleanup, rewired to run the cleanup.

        *items* is modified in place (presumably pytest's collected item
        list, which must be mutated rather than rebound).
        """
        new_items = []
        for obj in items:
            if self.has_object_cleanup(obj):
                self.override_runtest(obj)
                new_items.append(obj)
        items.clear()
        items.extend(new_items)

    def modify_tests(self, items, config):
        """Apply cleanup-mode rewriting when ``--atf-cleanup`` is set."""
        if config.option.atf_cleanup:
            self._generate_test_cleanups(items)

    def list_tests(self, tests: List[Any]):
        """Print the ATF test-program listing for the *tests* items."""
        print('Content-Type: application/X-atf-tp; version="1"')
        print()
        for test_obj in tests:
            has_cleanup = self.has_object_cleanup(test_obj)
            atf_test = ATFTestObj(test_obj, has_cleanup)
            for line in atf_test.as_lines():
                print(line)
            print()

    def set_report_state(self, test_name: str, state: str, reason: Optional[str]):
        """Record *state*/*reason* as the current result of *test_name*."""
        self._tests_state_map[test_name] = self.ReportState(state, reason)

    def _extract_report_reason(self, report) -> Optional[str]:
        """Return a one-line reason derived from ``report.longrepr``.

        Tuple reprs (skips) yield their message without a leading
        ``"Skipped: "``. Any other repr yields the last line of its string
        form. None is returned when there is no repr.
        """
        data = report.longrepr
        if data is None:
            return None
        if isinstance(data, tuple):
            # ('/path/to/test.py', 23, 'Skipped: unable to test')
            reason = data[2]
            prefix = "Skipped: "
            # Strip the whole prefix once. Iterating over the prefix string
            # would strip individual characters and mangle other reasons.
            if reason.startswith(prefix):
                reason = reason[len(prefix):]
            return reason
        # string/ traceback / exception report. Capture the last line
        return str(data).split("\n")[-1]

    def add_report(self, report):
        """Fold one pytest stage report into the per-test ATF state."""
        # MAP pytest report state to the atf-desired state
        #
        # ATF test states:
        # (1) expected_death, (2) expected_exit, (3) expected_failure
        # (4) expected_signal, (5) expected_timeout, (6) passed
        # (7) skipped, (8) failed
        #
        # Note that ATF don't have the concept of "soft xfail" - xpass
        # is a failure. It also calls teardown routine in a separate
        # process, thus teardown states (pytest-only) are handled as
        # body continuation.

        # (stage, state, wasxfail)

        # Just a passing test: WANT: passed
        # GOT: (setup, passed, F), (call, passed, F), (teardown, passed, F)
        #
        # Failing body test: WANT: failed
        # GOT: (setup, passed, F), (call, failed, F), (teardown, passed, F)
        #
        # pytest.skip test decorator: WANT: skipped
        # GOT: (setup,skipped, False), (teardown, passed, False)
        #
        # pytest.skip call inside test function: WANT: skipped
        # GOT: (setup, passed, F), (call, skipped, F), (teardown,passed, F)
        #
        # mark.xfail decorator+pytest.xfail: WANT: expected_failure
        # GOT: (setup, passed, F), (call, skipped, T), (teardown, passed, F)
        #
        # mark.xfail decorator+pass: WANT: failed
        # GOT: (setup, passed, F), (call, passed, T), (teardown, passed, F)

        test_name = report.location[2]
        stage = report.when
        state = report.outcome
        reason = self._extract_report_reason(report)

        # We don't care about strict xfail - it gets translated to False

        if stage == "setup":
            if state in ("skipped", "failed"):
                # failed init -> failed test, skipped setup -> xskip
                # for the whole test
                self.set_report_state(test_name, state, reason)
        elif stage == "call":
            # "call" stage shouldn't matter if setup failed
            previous = self._tests_state_map.get(test_name)
            if previous is not None and previous.state == "failed":
                return
            # ``wasxfail`` lives on the report object, not on the
            # extracted reason string.
            xfail_reason = getattr(report, "wasxfail", None)
            if state == "failed":
                # Record failure  & override "skipped" state
                self.set_report_state(test_name, state, reason)
            elif state == "skipped":
                if xfail_reason is not None:
                    # xfail() called in the test body / xfail-marked test
                    # failed as expected; prefer the xfail reason text.
                    state = "expected_failure"
                    reason = xfail_reason or reason
                self.set_report_state(test_name, state, reason)
            elif state == "passed":
                if xfail_reason is not None:
                    # the test was expected to fail but didn't
                    # mark as hard failure
                    state = "failed"
                    reason = "[XPASS] {}".format(xfail_reason).rstrip()
                self.set_report_state(test_name, state, reason)
        elif stage == "teardown":
            if state == "failed":
                # teardown should be empty, as the cleanup
                # procedures should be implemented as a separate
                # function/method, so mark teardown failure as
                # global failure
                self.set_report_state(test_name, state, reason)

    def write_report(self):
        """Write the ATF result line (if any) and close the result file."""
        handle = self._report_file_handle
        if handle is None:
            return
        # Drop the reference first so a repeated call is a no-op instead of
        # writing to a closed file.
        self._report_file_handle = None
        try:
            if self._tests_state_map:
                # If we're executing in ATF mode, there has to be just one test
                # Anyway, deterministically pick the first one
                first_test_name = next(iter(self._tests_state_map))
                test = self._tests_state_map[first_test_name]
                if test.state == "passed":
                    line = test.state
                else:
                    line = "{}: {}".format(test.state, test.reason)
                print(line, file=handle)
        finally:
            handle.close()

    @staticmethod
    def get_atf_vars() -> Dict[str, str]:
        """Return ``_ATF_VAR_``-prefixed environment variables, prefix removed."""
        px = "_ATF_VAR_"
        return {k[len(px):]: v for k, v in os.environ.items() if k.startswith(px)}
292        return {k[len(px):]: v for k, v in os.environ.items() if k.startswith(px)}
293