import errno
import logging
import time
from typing import NamedTuple

import pytest
from atf_python.sys.netlink.attrs import NlAttrNested
from atf_python.sys.netlink.attrs import NlAttrStr
from atf_python.sys.netlink.netlink import NetlinkMultipartIterator
from atf_python.sys.netlink.netlink import NlHelper
from atf_python.sys.netlink.netlink import Nlsock
from atf_python.sys.netlink.netlink_generic import KtestAttrType
from atf_python.sys.netlink.netlink_generic import KtestInfoMessage
from atf_python.sys.netlink.netlink_generic import KtestLogMsgType
from atf_python.sys.netlink.netlink_generic import KtestMsgAttrType
from atf_python.sys.netlink.netlink_generic import KtestMsgType
from atf_python.sys.netlink.netlink_generic import timespec
from atf_python.sys.netlink.utils import NlConst
from atf_python.utils import BaseTest
from atf_python.utils import libc
from atf_python.utils import nodeid_to_method_name


datefmt = "%H:%M:%S"
fmt = "%(asctime)s.%(msecs)03d %(filename)s:%(funcName)s:%(lineno)d %(message)s"
logging.basicConfig(level=logging.DEBUG, format=fmt, datefmt=datefmt)
logger = logging.getLogger("ktest")


# Name of the generic netlink family used to list and run kernel tests.
NETLINK_FAMILY = "ktest"


class KtestItem(pytest.Item):
    """Pytest item representing a single kernel test.

    Running the item instantiates ``kcls`` and calls its ``runtest()``.
    """

    def __init__(self, *, descr, kcls, **kwargs):
        """Store the test description and the class used to run the test.

        Args:
            descr: Description string reported for the kernel test.
            kcls: Class that is instantiated (with no arguments) to run it.
            **kwargs: Forwarded unchanged to ``pytest.Item.__init__``.
        """
        super().__init__(**kwargs)
        self.descr = descr
        self._kcls = kcls

    def runtest(self):
        """Run the kernel test through a fresh ``kcls`` instance."""
        self._kcls().runtest()


class KtestCollector(pytest.Class):
    """Collector yielding both regular Python tests and kernel tests."""

    def collect(self):
        """Yield the class's Python tests followed by its kernel tests.

        The wrapped class must define ``KTEST_MODULE_NAME`` and
        ``KTEST_MODULE_AUTOLOAD``; they select the kernel module whose
        tests are listed and whether it may be loaded on demand.
        """
        obj = self.obj
        # Kernel tests whose name matches a public attribute of the class
        # are skipped below (presumably because the class already provides
        # a Python test of that name via orig.collect() -- verify).
        exclude_names = {n for n in dir(obj) if not n.startswith("_")}

        autoload = obj.KTEST_MODULE_AUTOLOAD
        module_name = obj.KTEST_MODULE_NAME
        loader = KtestLoader(module_name, autoload)
        ktests = loader.load_ktests()
        if not ktests:
            # No kernel tests reported: collect nothing at all, including
            # the class's regular Python tests.
            return

        orig = pytest.Class.from_parent(self.parent, name=self.name, obj=obj)
        for py_test in orig.collect():
            yield py_test

        for ktest in ktests:
            name = ktest["name"]
            descr = ktest["desc"]
            if name in exclude_names:
                continue
            yield KtestItem.from_parent(self, name=name, descr=descr, kcls=obj)


class KtestLoader(object):
    """Lists the tests a kernel module exposes through the ktest family."""

    def __init__(self, module_name: str, autoload: bool):
        """Open a generic netlink socket and resolve the ktest family id.

        Args:
            module_name: Name of the kernel module providing the tests.
            autoload: Whether ``module_name`` may be loaded on demand.
        """
        self.module_name = module_name
        self.autoload = autoload
        # Ensure the base ktest.ko module is loaded.  A failure here is only
        # logged; a missing family is reported by _get_family_id() below.
        result = libc.kldload("ktest")
        if result not in (0, errno.EEXIST):  # EEXIST: already loaded
            logger.debug(f"Failed to load base ktest module (error {result})")
        self.helper = NlHelper()
        self.nlsock = Nlsock(NlConst.NETLINK_GENERIC, self.helper)
        self.family_id = self._get_family_id()

    def _load_test_module(self):
        """Load ``self.module_name``, treating "already loaded" as success.

        Raises:
            RuntimeError: if kldload reports any error other than EEXIST.
        """
        result = libc.kldload(self.module_name)
        if result not in (0, errno.EEXIST):  # EEXIST: already loaded
            raise RuntimeError(
                f"Failed to load kernel module '{self.module_name}' (error {result})"
            )

    def _get_family_id(self):
        """Return the generic netlink family id of the ktest family.

        If the lookup raises ``ValueError`` and autoload is enabled, the
        test module is loaded and the lookup is retried once; with autoload
        disabled the ``ValueError`` propagates.
        """
        try:
            return self.nlsock.get_genl_family_id(NETLINK_FAMILY)
        except ValueError:
            if not self.autoload:
                raise
            self._load_test_module()
            return self.nlsock.get_genl_family_id(NETLINK_FAMILY)

    def _load_ktests(self):
        """Send KTEST_CMD_LIST for the module and return the reported tests.

        Returns:
            A list of dicts with the keys ``mod_name``, ``name`` and
            ``desc``, one per message of the multipart reply.
        """
        msg = KtestInfoMessage(self.helper, self.family_id, KtestMsgType.KTEST_CMD_LIST)
        msg.set_request()
        msg.add_nla(NlAttrStr(KtestAttrType.KTEST_ATTR_MOD_NAME, self.module_name))
        self.nlsock.write_message(msg, verbose=False)
        nlmsg_seq = msg.nl_hdr.nlmsg_seq

        return [
            {
                "mod_name": rx_msg.get_nla(KtestAttrType.KTEST_ATTR_MOD_NAME).text,
                "name": rx_msg.get_nla(KtestAttrType.KTEST_ATTR_TEST_NAME).text,
                "desc": rx_msg.get_nla(KtestAttrType.KTEST_ATTR_TEST_DESCR).text,
            }
            for rx_msg in NetlinkMultipartIterator(
                self.nlsock, nlmsg_seq, self.family_id
            )
        ]

    def load_ktests(self):
        """Return the module's tests, loading the module once if needed.

        When the first listing is empty and autoload is enabled, the module
        is loaded and the listing is repeated.

        Raises:
            RuntimeError: if the module has to be loaded and loading fails.
        """
        ret = self._load_ktests()
        if not ret and self.autoload:
            self._load_test_module()
            ret = self._load_ktests()
        return ret


def generate_ktests(collector, name, obj):
    """Return a KtestCollector for ``obj`` if it names a kernel module.

    Objects without a non-None ``KTEST_MODULE_NAME`` attribute yield None.
    """
    if getattr(obj, "KTEST_MODULE_NAME", None) is not None:
        return KtestCollector.from_parent(collector, name=name, obj=obj)
    return None


class BaseKernelTest(BaseTest):
    """Base class for tests that execute inside a kernel test module.

    Subclasses set ``KTEST_MODULE_NAME`` to the module providing the tests.
    """

    KTEST_MODULE_AUTOLOAD = True
    KTEST_MODULE_NAME = None

    def _get_record_time(self, msg) -> float:
        """Map the kernel timestamp of ``msg`` onto the wall clock.

        The first message seen by this instance anchors the mapping: its
        kernel timestamp is paired with the current ``time.time()``.  Later
        messages are placed relative to that anchor using only the
        difference between kernel timestamps, so the result does not depend
        on which clock the kernel timestamps are based on.
        """
        ts = msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_TS).ts
        epoch_ktime = ts.tv_sec + ts.tv_nsec / 1_000_000_000
        if not hasattr(self, "_start_ktime"):
            self._start_ktime = epoch_ktime
            self._start_time = time.time()
            return self._start_time
        return self._start_time + (epoch_ktime - self._start_ktime)

    def _log_message(self, msg):
        """Re-emit a kernel log message through the ``ktest`` logger."""
        # Convert the syslog-type level to a Python logging level:
        # severities up to 6 (LOG_INFO) become INFO, the rest DEBUG.
        syslog_level = msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_LEVEL).u8
        if syslog_level <= 6:
            loglevel = logging.INFO
        else:
            loglevel = logging.DEBUG
        text = msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_TEXT).text
        rec = logging.LogRecord(
            self.KTEST_MODULE_NAME,
            loglevel,
            msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_FILE).text,
            msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_LINE).u32,
            "%s",
            # Must be a real tuple: a bare (possibly empty) string is falsy
            # when empty, which makes the record print a literal "%s".
            (text,),
            None,
            msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_FUNC).text,
            None,
        )
        rec.created = self._get_record_time(msg)
        logger.handle(rec)

    def _runtest_name(self, test_name: str, test_data):
        """Run kernel test ``test_name`` and log every message it emits.

        Args:
            test_name: Name of the test inside ``KTEST_MODULE_NAME``.
            test_data: Optional attributes attached as KTEST_ATTR_TEST_META;
                nothing is attached when it is None.
        """
        module_name = self.KTEST_MODULE_NAME
        helper = NlHelper()
        nlsock = Nlsock(NlConst.NETLINK_GENERIC, helper)
        family_id = nlsock.get_genl_family_id(NETLINK_FAMILY)
        msg = KtestInfoMessage(helper, family_id, KtestMsgType.KTEST_CMD_RUN)
        msg.set_request()
        msg.add_nla(NlAttrStr(KtestAttrType.KTEST_ATTR_MOD_NAME, module_name))
        msg.add_nla(NlAttrStr(KtestAttrType.KTEST_ATTR_TEST_NAME, test_name))
        if test_data is not None:
            msg.add_nla(NlAttrNested(KtestAttrType.KTEST_ATTR_TEST_META, test_data))
        nlsock.write_message(msg, verbose=False)

        for log_msg in NetlinkMultipartIterator(
            nlsock, msg.nl_hdr.nlmsg_seq, family_id
        ):
            self._log_message(log_msg)

    def runtest(self, test_data=None):
        """Run the kernel test whose name is derived from ``self.test_id``."""
        self._runtest_name(nodeid_to_method_name(self.test_id), test_data)