import logging
import time
from typing import NamedTuple

import pytest
from atf_python.sys.netlink.attrs import NlAttrNested
from atf_python.sys.netlink.attrs import NlAttrStr
from atf_python.sys.netlink.netlink import NetlinkMultipartIterator
from atf_python.sys.netlink.netlink import NlHelper
from atf_python.sys.netlink.netlink import Nlsock
from atf_python.sys.netlink.netlink_generic import KtestAttrType
from atf_python.sys.netlink.netlink_generic import KtestInfoMessage
from atf_python.sys.netlink.netlink_generic import KtestLogMsgType
from atf_python.sys.netlink.netlink_generic import KtestMsgAttrType
from atf_python.sys.netlink.netlink_generic import KtestMsgType
from atf_python.sys.netlink.netlink_generic import timespec
from atf_python.sys.netlink.utils import NlConst
from atf_python.utils import BaseTest
from atf_python.utils import libc
from atf_python.utils import nodeid_to_method_name


datefmt = "%H:%M:%S"
fmt = "%(asctime)s.%(msecs)03d %(filename)s:%(funcName)s:%(lineno)d %(message)s"
logging.basicConfig(level=logging.DEBUG, format=fmt, datefmt=datefmt)
logger = logging.getLogger("ktest")


# Generic netlink family name used to list and run kernel tests.
NETLINK_FAMILY = "ktest"


class KtestItem(pytest.Item):
    """pytest item representing a single kernel-side test.

    ``descr`` is the test description reported by the kernel; ``kcls`` is
    the test class that is instantiated to drive the run.
    """

    def __init__(self, *, descr, kcls, **kwargs):
        super().__init__(**kwargs)
        self.descr = descr
        self._kcls = kcls

    def runtest(self):
        # A fresh test-class instance is created for every run.
        self._kcls().runtest()


class KtestCollector(pytest.Class):
    """Collect the Python tests of a class plus its kernel-side tests."""

    def collect(self):
        obj = self.obj
        # Kernel tests whose name matches a public attribute of the class are
        # skipped, so an explicitly defined Python method takes precedence.
        exclude_names = {n for n in dir(obj) if not n.startswith("_")}

        autoload = obj.KTEST_MODULE_AUTOLOAD
        module_name = obj.KTEST_MODULE_NAME
        loader = KtestLoader(module_name, autoload)
        ktests = loader.load_ktests()
        if not ktests:
            return

        # Regular collection of the Python-defined tests of the same class.
        orig = pytest.Class.from_parent(self.parent, name=self.name, obj=obj)
        yield from orig.collect()

        for ktest in ktests:
            name = ktest["name"]
            if name in exclude_names:
                continue
            yield KtestItem.from_parent(
                self, name=name, descr=ktest["desc"], kcls=obj
            )


class KtestLoader:
    """List the kernel tests of a module via the ktest netlink family."""

    def __init__(self, module_name: str, autoload: bool):
        self.module_name = module_name
        self.autoload = autoload
        self.helper = NlHelper()
        self.nlsock = Nlsock(NlConst.NETLINK_GENERIC, self.helper)
        self.family_id = self._get_family_id()

    def _get_family_id(self):
        """Return the genl family id for NETLINK_FAMILY.

        If the lookup raises ValueError and autoload is enabled, the test
        module is loaded with kldload and the lookup is retried once;
        otherwise the ValueError is re-raised.
        """
        try:
            family_id = self.nlsock.get_genl_family_id(NETLINK_FAMILY)
        except ValueError:
            if not self.autoload:
                raise
            libc.kldload(self.module_name)
            family_id = self.nlsock.get_genl_family_id(NETLINK_FAMILY)
        return family_id

    def _load_ktests(self):
        """Send KTEST_CMD_LIST for the module and return the replies.

        Each entry is a dict with ``mod_name``, ``name`` and ``desc`` keys.
        """
        msg = KtestInfoMessage(self.helper, self.family_id, KtestMsgType.KTEST_CMD_LIST)
        msg.set_request()
        msg.add_nla(NlAttrStr(KtestAttrType.KTEST_ATTR_MOD_NAME, self.module_name))
        self.nlsock.write_message(msg, verbose=False)
        nlmsg_seq = msg.nl_hdr.nlmsg_seq

        return [
            {
                "mod_name": rx_msg.get_nla(KtestAttrType.KTEST_ATTR_MOD_NAME).text,
                "name": rx_msg.get_nla(KtestAttrType.KTEST_ATTR_TEST_NAME).text,
                "desc": rx_msg.get_nla(KtestAttrType.KTEST_ATTR_TEST_DESCR).text,
            }
            for rx_msg in NetlinkMultipartIterator(
                self.nlsock, nlmsg_seq, self.family_id
            )
        ]

    def load_ktests(self):
        """Return the module's kernel tests.

        If none are reported and autoload is enabled, the module is loaded
        with kldload and the listing is retried once.
        """
        ret = self._load_ktests()
        if not ret and self.autoload:
            libc.kldload(self.module_name)
            ret = self._load_ktests()
        return ret


def generate_ktests(collector, name, obj):
    """Return a KtestCollector for classes defining KTEST_MODULE_NAME, else None.

    Presumably wired into a pytest collection hook elsewhere (not shown here).
    """
    if getattr(obj, "KTEST_MODULE_NAME", None) is not None:
        return KtestCollector.from_parent(collector, name=name, obj=obj)
    return None


class BaseKernelTest(BaseTest):
    """Base class for test classes backed by a kernel test module."""

    # Whether to kldload KTEST_MODULE_NAME when its tests are not found.
    KTEST_MODULE_AUTOLOAD = True
    # Kernel module whose tests are collected; subclasses must set it.
    KTEST_MODULE_NAME = None

    def _get_record_time(self, msg) -> float:
        """Map a kernel message timestamp onto the local wall clock.

        The first message seen by this instance anchors kernel time to the
        current ``time.time()``. Later messages are placed at that anchor
        plus their kernel-time offset from the first message.
        """
        ts = msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_TS).ts
        epoch_ktime = ts.tv_sec * 1.0 + ts.tv_nsec * 1.0 / 1000000000
        # Check the attribute that is actually set below, so the anchor is
        # only taken once.
        if not hasattr(self, "_start_ktime"):
            self._start_ktime = epoch_ktime
            self._start_time = time.time()
        return self._start_time + (epoch_ktime - self._start_ktime)

    def _log_message(self, msg):
        """Re-emit one kernel log message through the ``ktest`` logger."""
        # Convert the syslog-type level (lower is more severe): levels up to
        # and including 6 (LOG_INFO) become INFO, anything above becomes DEBUG.
        syslog_level = msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_LEVEL).u8
        loglevel = logging.INFO if syslog_level <= 6 else logging.DEBUG
        rec = logging.LogRecord(
            self.KTEST_MODULE_NAME,
            loglevel,
            msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_FILE).text,
            msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_LINE).u32,
            "%s",
            # A real 1-tuple: with a bare (possibly empty) string, an empty
            # text would be falsy and the message would render as "%s".
            (msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_TEXT).text,),
            None,
            msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_FUNC).text,
            None,
        )
        created = self._get_record_time(msg)
        rec.created = created
        # LogRecord computes msecs from its own construction time; recompute
        # it so "%(msecs)03d" in the format agrees with the overridden time.
        rec.msecs = int((created - int(created)) * 1000) + 0.0
        logger.handle(rec)

    def _runtest_name(self, test_name: str, test_data):
        """Run kernel test ``test_name`` and log every message it sends back.

        ``test_data``, when not None, is attached as a nested
        KTEST_ATTR_TEST_META attribute.
        """
        module_name = self.KTEST_MODULE_NAME
        helper = NlHelper()
        nlsock = Nlsock(NlConst.NETLINK_GENERIC, helper)
        family_id = nlsock.get_genl_family_id(NETLINK_FAMILY)
        msg = KtestInfoMessage(helper, family_id, KtestMsgType.KTEST_CMD_RUN)
        msg.set_request()
        msg.add_nla(NlAttrStr(KtestAttrType.KTEST_ATTR_MOD_NAME, module_name))
        msg.add_nla(NlAttrStr(KtestAttrType.KTEST_ATTR_TEST_NAME, test_name))
        if test_data is not None:
            msg.add_nla(NlAttrNested(KtestAttrType.KTEST_ATTR_TEST_META, test_data))
        nlsock.write_message(msg, verbose=False)

        for log_msg in NetlinkMultipartIterator(
            nlsock, msg.nl_hdr.nlmsg_seq, family_id
        ):
            self._log_message(log_msg)

    def runtest(self, test_data=None):
        """Run the kernel test named after the current test method."""
        self._runtest_name(nodeid_to_method_name(self.test_id), test_data)