xref: /freebsd/tests/atf_python/ktest.py (revision 43e29d03f416d7dda52112a29600a7c82ee1a91e)
1import logging
2import time
3from typing import NamedTuple
4
5import pytest
6from atf_python.sys.netlink.attrs import NlAttrNested
7from atf_python.sys.netlink.attrs import NlAttrStr
8from atf_python.sys.netlink.netlink import NetlinkMultipartIterator
9from atf_python.sys.netlink.netlink import NlHelper
10from atf_python.sys.netlink.netlink import Nlsock
11from atf_python.sys.netlink.netlink_generic import KtestAttrType
12from atf_python.sys.netlink.netlink_generic import KtestInfoMessage
13from atf_python.sys.netlink.netlink_generic import KtestLogMsgType
14from atf_python.sys.netlink.netlink_generic import KtestMsgAttrType
15from atf_python.sys.netlink.netlink_generic import KtestMsgType
16from atf_python.sys.netlink.netlink_generic import timespec
17from atf_python.sys.netlink.utils import NlConst
18from atf_python.utils import BaseTest
19from atf_python.utils import libc
20from atf_python.utils import nodeid_to_method_name
21
22
# Log layout shared by Python-side records and records relayed from the
# kernel: wall-clock time with milliseconds, source location, then the text.
datefmt = "%H:%M:%S"
fmt = " ".join(
    (
        "%(asctime)s.%(msecs)03d",  # timestamp
        "%(filename)s:%(funcName)s:%(lineno)d",  # origin
        "%(message)s",  # payload
    )
)
logging.basicConfig(format=fmt, datefmt=datefmt, level=logging.DEBUG)
logger = logging.getLogger("ktest")


# Generic netlink family name used to reach the in-kernel test framework.
NETLINK_FAMILY = "ktest"
30
31
class KtestItem(pytest.Item):
    """A pytest item standing for one test implemented in a kernel module.

    Running the item instantiates the owning test class and delegates to
    that instance's ``runtest()``.
    """

    def __init__(self, *, descr, kcls, **kwargs):
        super().__init__(**kwargs)
        # Test class whose runtest() actually executes the kernel test.
        self._kcls = kcls
        # Test description, as listed by the kernel module.
        self.descr = descr

    def runtest(self):
        test_instance = self._kcls()
        test_instance.runtest()
40
41
class KtestCollector(pytest.Class):
    """Collect Python-defined and kernel-defined tests for one test class.

    The class names a kernel module via ``KTEST_MODULE_NAME``. The tests that
    module lists over netlink become ``KtestItem`` nodes, except those whose
    name matches a public attribute of the Python class.
    """

    def collect(self):
        test_cls = self.obj
        # Public attribute names of the Python class. Kernel tests with one of
        # these names are skipped below (presumably the Python class provides
        # its own wrapper for them).
        python_names = {attr for attr in dir(test_cls) if not attr.startswith("_")}

        loader = KtestLoader(test_cls.KTEST_MODULE_NAME, test_cls.KTEST_MODULE_AUTOLOAD)
        kernel_tests = loader.load_ktests()
        if not kernel_tests:
            # No kernel tests available: yield nothing, not even Python tests.
            return

        # Regular pytest collection of the Python-defined test methods.
        python_collector = pytest.Class.from_parent(
            self.parent, name=self.name, obj=test_cls
        )
        yield from python_collector.collect()

        for ktest in kernel_tests:
            test_name, test_descr = ktest["name"], ktest["desc"]
            if test_name not in python_names:
                yield KtestItem.from_parent(
                    self, name=test_name, descr=test_descr, kcls=test_cls
                )
64
65
class KtestLoader:
    """List the tests a kernel module exposes through the ktest netlink family.

    Args:
        module_name: kernel module whose tests are listed.
        autoload: when true, ``kldload`` the module if the ktest family or
            the module's tests cannot be found.
    """

    def __init__(self, module_name: str, autoload: bool):
        self.module_name = module_name
        self.autoload = autoload
        self.helper = NlHelper()
        self.nlsock = Nlsock(NlConst.NETLINK_GENERIC, self.helper)
        # Resolved eagerly; may kldload the module (see _get_family_id).
        self.family_id = self._get_family_id()

    def _get_family_id(self):
        """Return the generic netlink family id registered for NETLINK_FAMILY.

        Raises:
            ValueError: propagated from the lookup when the family is missing
                and autoload is disabled, or if it is still missing after the
                module has been loaded.
        """
        try:
            return self.nlsock.get_genl_family_id(NETLINK_FAMILY)
        except ValueError:
            if not self.autoload:
                raise
            # Family not found: load the requested module and retry once.
            libc.kldload(self.module_name)
            return self.nlsock.get_genl_family_id(NETLINK_FAMILY)

    def _load_ktests(self):
        """Send a KTEST_CMD_LIST request and return one dict per listed test.

        Each dict carries ``mod_name``, ``name`` and ``desc`` taken from the
        matching attributes of the reply message.
        """
        request = KtestInfoMessage(
            self.helper, self.family_id, KtestMsgType.KTEST_CMD_LIST
        )
        request.set_request()
        request.add_nla(NlAttrStr(KtestAttrType.KTEST_ATTR_MOD_NAME, self.module_name))
        self.nlsock.write_message(request, verbose=False)
        seq = request.nl_hdr.nlmsg_seq

        replies = NetlinkMultipartIterator(self.nlsock, seq, self.family_id)
        return [
            {
                "mod_name": reply.get_nla(KtestAttrType.KTEST_ATTR_MOD_NAME).text,
                "name": reply.get_nla(KtestAttrType.KTEST_ATTR_TEST_NAME).text,
                "desc": reply.get_nla(KtestAttrType.KTEST_ATTR_TEST_DESCR).text,
            }
            for reply in replies
        ]

    def load_ktests(self):
        """Return the module's test list, loading the module once if it lists none.

        The kldload-and-retry only happens when ``autoload`` is set.
        """
        tests = self._load_ktests()
        if tests or not self.autoload:
            return tests
        # Nothing listed: the module is probably not loaded yet.
        libc.kldload(self.module_name)
        return self._load_ktests()
108        return ret
109
110
def generate_ktests(collector, name, obj):
    """Return a KtestCollector for *obj* if it names a kernel test module.

    Objects without a non-None ``KTEST_MODULE_NAME`` yield ``None``.
    """
    if getattr(obj, "KTEST_MODULE_NAME", None) is None:
        return None
    return KtestCollector.from_parent(collector, name=name, obj=obj)
115
116
class BaseKernelTest(BaseTest):
    """Base class for test classes that drive tests implemented in a kernel module.

    Subclasses set ``KTEST_MODULE_NAME``. Each kernel test is started over the
    ktest generic netlink family, and the log messages returned for it are
    re-emitted through the ``ktest`` Python logger.
    """

    # Whether the collector may kldload KTEST_MODULE_NAME when it is missing.
    KTEST_MODULE_AUTOLOAD = True
    # Kernel module providing the tests; None means "not a ktest class".
    KTEST_MODULE_NAME = None

    def _get_record_time(self, msg) -> float:
        """Translate a kernel log message timestamp into a wall-clock epoch time.

        The first record seen by this instance anchors the kernel timestamp
        to the current wall-clock time. Later records are offset from that
        anchor by the kernel-side time difference, which preserves the
        relative spacing of kernel log lines. Only differences of kernel
        timestamps are used, so the kernel clock's epoch does not matter.
        """
        # Named `ts` to avoid shadowing the module-level `timespec` import.
        ts = msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_TS).ts
        ktime = ts.tv_sec + ts.tv_nsec / 1e9
        # Check the attribute that is actually set below. The old check
        # tested a never-set attribute, so the anchor was reset every time.
        if not hasattr(self, "_start_ktime"):
            self._start_ktime = ktime
            self._start_time = time.time()
        return self._start_time + (ktime - self._start_ktime)

    def _log_message(self, msg):
        """Re-emit one kernel log message as a record on the ``ktest`` logger."""
        # Map the syslog-style severity onto Python levels: everything up to
        # LOG_INFO (6) is logged at INFO, more verbose levels at DEBUG.
        syslog_level = msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_LEVEL).u8
        loglevel = logging.INFO if syslog_level <= 6 else logging.DEBUG
        rec = logging.LogRecord(
            self.KTEST_MODULE_NAME,
            loglevel,
            msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_FILE).text,
            msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_LINE).u32,
            "%s",
            # A real 1-tuple: a bare empty string would count as "no args"
            # and the record would render as a literal "%s".
            (msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_TEXT).text,),
            None,
            msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_FUNC).text,
            None,
        )
        created = self._get_record_time(msg)
        # LogRecord derives msecs/relativeCreated from its own construction
        # time; keep them consistent with the overridden creation time so the
        # "%(msecs)03d" field matches "%(asctime)s".
        rec.relativeCreated += (created - rec.created) * 1000
        rec.created = created
        rec.msecs = (created - int(created)) * 1000
        logger.handle(rec)

    def _runtest_name(self, test_name: str, test_data):
        """Run kernel test *test_name* of KTEST_MODULE_NAME and relay its log output.

        Args:
            test_name: name of the kernel test to run.
            test_data: optional payload attached nested under
                KTEST_ATTR_TEST_META; presumably a list of netlink attribute
                objects accepted by NlAttrNested (confirm against callers).
        """
        module_name = self.KTEST_MODULE_NAME
        helper = NlHelper()
        nlsock = Nlsock(NlConst.NETLINK_GENERIC, helper)
        family_id = nlsock.get_genl_family_id(NETLINK_FAMILY)
        msg = KtestInfoMessage(helper, family_id, KtestMsgType.KTEST_CMD_RUN)
        msg.set_request()
        msg.add_nla(NlAttrStr(KtestAttrType.KTEST_ATTR_MOD_NAME, module_name))
        msg.add_nla(NlAttrStr(KtestAttrType.KTEST_ATTR_TEST_NAME, test_name))
        if test_data is not None:
            msg.add_nla(NlAttrNested(KtestAttrType.KTEST_ATTR_TEST_META, test_data))
        nlsock.write_message(msg, verbose=False)

        # Each reply in the multipart response is handled as a log message.
        for log_msg in NetlinkMultipartIterator(
            nlsock, msg.nl_hdr.nlmsg_seq, family_id
        ):
            self._log_message(log_msg)

    def runtest(self, test_data=None):
        """Run the kernel test named after the currently executing pytest test.

        The name is derived from ``self.test_id`` with nodeid_to_method_name.
        ``test_id`` is not defined in this class; presumably BaseTest
        supplies it.
        """
        self._runtest_name(nodeid_to_method_name(self.test_id), test_data)
173        self._runtest_name(nodeid_to_method_name(self.test_id), test_data)
174