#!/usr/local/bin/python3
import os
import socket
import sys
from ctypes import c_int
from ctypes import c_ubyte
from ctypes import c_uint
from ctypes import c_ushort
from ctypes import sizeof
from ctypes import Structure
from enum import auto
from enum import Enum
from typing import Optional

from atf_python.sys.netlink.attrs import NlAttr
from atf_python.sys.netlink.attrs import NlAttrStr
from atf_python.sys.netlink.attrs import NlAttrU32
from atf_python.sys.netlink.base_headers import GenlMsgHdr
from atf_python.sys.netlink.base_headers import NlmBaseFlags
from atf_python.sys.netlink.base_headers import Nlmsghdr
from atf_python.sys.netlink.base_headers import NlMsgType
from atf_python.sys.netlink.message import BaseNetlinkMessage
from atf_python.sys.netlink.message import NlMsgCategory
from atf_python.sys.netlink.message import NlMsgProps
from atf_python.sys.netlink.message import StdNetlinkMessage
from atf_python.sys.netlink.netlink_generic import GenlCtrlAttrType
from atf_python.sys.netlink.netlink_generic import GenlCtrlMsgType
from atf_python.sys.netlink.netlink_generic import handler_classes as genl_classes
from atf_python.sys.netlink.netlink_route import handler_classes as rt_classes
from atf_python.sys.netlink.utils import align4
from atf_python.sys.netlink.utils import AttrDescr
from atf_python.sys.netlink.utils import build_propmap
from atf_python.sys.netlink.utils import enum_or_int
from atf_python.sys.netlink.utils import get_bitmask_map
from atf_python.sys.netlink.utils import NlConst
from atf_python.sys.netlink.utils import prepare_attrs_map


# Netlink socket-option level and option names.  The numeric values are the
# ones this module always passed to setsockopt(); names follow netlink(7).
_SOL_NETLINK = 270
_NETLINK_ADD_MEMBERSHIP = 1
_NETLINK_CAP_ACK = 10
_NETLINK_EXT_ACK = 11

# Flag on an NLMSG_ERROR reply's own header: when set, the echoed request
# header is NOT followed by the request payload (see NETLINK_CAP_ACK).
_NLM_F_CAPPED = 0x100


class SockaddrNl(Structure):
    """ctypes layout of a netlink socket address (sockaddr_nl)."""

    _fields_ = [
        ("nl_len", c_ubyte),
        ("nl_family", c_ubyte),
        ("nl_pad", c_ushort),
        ("nl_pid", c_uint),
        ("nl_groups", c_uint),
    ]


class Nlmsgdone(Structure):
    """Fixed header of an NLMSG_DONE payload: a single int error code."""

    _fields_ = [
        ("error", c_int),
    ]


class Nlmsgerr(Structure):
    """Fixed header of an NLMSG_ERROR payload: error code + echoed header."""

    _fields_ = [
        ("error", c_int),
        ("msg", Nlmsghdr),
    ]


class NlErrattrType(Enum):
    """Attribute types that may follow an NLMSG_ERROR header (extended ACK)."""

    NLMSGERR_ATTR_UNUSED = 0
    NLMSGERR_ATTR_MSG = auto()
    NLMSGERR_ATTR_OFFS = auto()
    NLMSGERR_ATTR_COOKIE = auto()
    NLMSGERR_ATTR_POLICY = auto()


class AddressFamilyLinux(Enum):
    """Address-family numbers used when not running on FreeBSD."""

    AF_INET = socket.AF_INET
    AF_INET6 = socket.AF_INET6
    AF_NETLINK = 16


class AddressFamilyBsd(Enum):
    """Address-family numbers used on FreeBSD (AF_NETLINK is 38 there)."""

    AF_INET = socket.AF_INET
    AF_INET6 = socket.AF_INET6
    AF_NETLINK = 38


class NlHelper:
    """Shared helper: sequence numbers, enum name lookups and AF mapping."""

    def __init__(self):
        # Cache: class -> build_propmap(class) result (see get_propmap()).
        self._pmap = {}
        self._af_cls = self.get_af_cls()
        self._seq_counter = 1
        self.pid = os.getpid()

    def get_seq(self):
        """Return the next sequence number (1, 2, 3, ...)."""
        ret = self._seq_counter
        self._seq_counter += 1
        return ret

    def get_af_cls(self):
        """Return the address-family enum for the running platform."""
        if sys.platform.startswith("freebsd"):
            return AddressFamilyBsd
        return AddressFamilyLinux

    def get_propmap(self, cls):
        """Return ``build_propmap(cls)``, computed once per class and cached."""
        if cls not in self._pmap:
            self._pmap[cls] = build_propmap(cls)
        return self._pmap[cls]

    def get_name_propmap(self, cls):
        """Return a ``{member_name: member_value}`` dict for enum ``cls``.

        Iterates the enum members rather than ``dir(cls)``.  For enums with a
        mixin type (e.g. ``IntEnum``), ``dir()`` also lists the mixin's public
        attributes, which have no ``.value``.
        """
        return {member.name: member.value for member in cls}

    def get_attr_byval(self, cls, attr_val):
        """Return the name mapped to ``attr_val`` in ``cls``, or None."""
        propmap = self.get_propmap(cls)
        return propmap.get(attr_val)

    def get_af_name(self, family):
        """Return the AF_* name for ``family``, or ``"af#<n>"`` if unknown."""
        v = self.get_attr_byval(self._af_cls, family)
        if v is not None:
            return v
        return "af#{}".format(family)

    def get_af_value(self, family_str: str) -> Optional[int]:
        """Return the numeric value of AF name ``family_str``, or None."""
        propmap = self.get_name_propmap(self._af_cls)
        return propmap.get(family_str)

    def get_bitmask_str(self, cls, val):
        """Return comma-joined flag names of ``cls`` that are set in ``val``."""
        bmap = get_bitmask_map(self.get_propmap(cls), val)
        return ",".join(bmap.values())

    @staticmethod
    def get_bitmask_str_uncached(cls, val):
        """Like get_bitmask_str(), but without an instance or the cache.

        Uses the module-level ``build_propmap`` / ``get_bitmask_map`` helpers.
        NlHelper itself has no such attributes, so the previous
        ``NlHelper.build_propmap`` lookup always raised AttributeError.
        """
        pmap = build_propmap(cls)
        bmap = get_bitmask_map(pmap, val)
        return ",".join(bmap.values())


nldone_attrs = prepare_attrs_map([])

nlerr_attrs = prepare_attrs_map(
    [
        AttrDescr(NlErrattrType.NLMSGERR_ATTR_MSG, NlAttrStr),
        AttrDescr(NlErrattrType.NLMSGERR_ATTR_OFFS, NlAttrU32),
        AttrDescr(NlErrattrType.NLMSGERR_ATTR_COOKIE, NlAttr),
    ]
)


class NetlinkDoneMessage(StdNetlinkMessage):
    """NLMSG_DONE: terminates a multipart reply; carries an error code."""

    messages = [NlMsgProps(NlMsgType.NLMSG_DONE, NlMsgCategory.ACK)]
    nl_attrs_map = nldone_attrs

    @property
    def error_code(self):
        return self.base_hdr.error

    def parse_base_header(self, data):
        """Parse the Nlmsgdone header; return ``(header, consumed_bytes)``."""
        if len(data) < sizeof(Nlmsgdone):
            raise ValueError("length less than nlmsgdone header")
        done_hdr = Nlmsgdone.from_buffer_copy(data)
        sz = sizeof(Nlmsgdone)
        return (done_hdr, sz)

    def print_base_header(self, hdr, prepend=""):
        print("{}error={}".format(prepend, hdr.error))


class NetlinkErrorMessage(StdNetlinkMessage):
    """NLMSG_ERROR: error code (0 means ACK) plus optional extended-ACK attrs."""

    messages = [NlMsgProps(NlMsgType.NLMSG_ERROR, NlMsgCategory.ACK)]
    nl_attrs_map = nlerr_attrs

    @property
    def error_code(self):
        return self.base_hdr.error

    @property
    def error_str(self):
        """Text of the NLMSGERR_ATTR_MSG attribute, or None if absent."""
        nla = self.get_nla(NlErrattrType.NLMSGERR_ATTR_MSG)
        if nla:
            return nla.text
        return None

    @property
    def error_offset(self):
        """Value of the NLMSGERR_ATTR_OFFS attribute, or None if absent."""
        nla = self.get_nla(NlErrattrType.NLMSGERR_ATTR_OFFS)
        if nla:
            return nla.u32
        return None

    @property
    def cookie(self):
        """The raw NLMSGERR_ATTR_COOKIE attribute object (or get_nla's miss value)."""
        return self.get_nla(NlErrattrType.NLMSGERR_ATTR_COOKIE)

    def parse_base_header(self, data):
        """Parse the Nlmsgerr header; return ``(header, consumed_bytes)``.

        Unless the reply is flagged as capped, the echoed request payload
        follows the embedded header and is counted as consumed, so attribute
        parsing starts after it.
        """
        if len(data) < sizeof(Nlmsgerr):
            raise ValueError("length less than nlmsgerr header")
        err_hdr = Nlmsgerr.from_buffer_copy(data)
        sz = sizeof(Nlmsgerr)
        if (self.nl_hdr.nlmsg_flags & _NLM_F_CAPPED) == 0:
            # Clamp so a malformed embedded nlmsg_len smaller than a header
            # cannot produce a negative skip.
            payload_len = max(0, err_hdr.msg.nlmsg_len - sizeof(Nlmsghdr))
            sz += align4(payload_len)
        return (err_hdr, sz)

    def print_base_header(self, errhdr, prepend=""):
        print("{}error={}, ".format(prepend, errhdr.error), end="")
        hdr = errhdr.msg
        print(
            "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format(
                prepend,
                hdr.nlmsg_len,
                "msg#{}".format(hdr.nlmsg_type),
                self.helper.get_bitmask_str(NlmBaseFlags, hdr.nlmsg_flags),
                hdr.nlmsg_flags,
                hdr.nlmsg_seq,
                hdr.nlmsg_pid,
            )
        )


core_classes = {
    "netlink_core": [
        NetlinkDoneMessage,
        NetlinkErrorMessage,
    ],
}


class Nlsock:
    """Raw netlink socket that writes requests and parses replies into classes."""

    HANDLER_CLASSES = [core_classes, rt_classes, genl_classes]

    def __init__(self, family, helper):
        self.helper = helper
        # NOTE: despite its name this holds the socket.socket object.
        self.sock_fd = self._setup_netlink(family)
        self._sock_family = family
        # Bytes received but not yet consumed by read_message().
        self._data = bytes()
        self.msgmap = self.build_msgmap()
        # Generic-netlink family id -> family name; get_genl_family_id() adds.
        self._family_map = {
            NlConst.GENL_ID_CTRL: "nlctrl",
        }

    def build_msgmap(self):
        """Return ``{family_name: {msg type/cmd value: handler class}}``."""
        handler_classes = {}
        for d in self.HANDLER_CLASSES:
            handler_classes.update(d)
        xmap = {}
        for family_id, family_classes in handler_classes.items():
            xmap[family_id] = {}
            for cls in family_classes:
                for msg_props in cls.messages:
                    xmap[family_id][enum_or_int(msg_props.msg)] = cls
        return xmap

    def _setup_netlink(self, netlink_family) -> socket.socket:
        """Open a raw AF_NETLINK socket with CAP_ACK and EXT_ACK enabled.

        The socket is closed again if configuring it fails, so it is not leaked.
        """
        family = self.helper.get_af_value("AF_NETLINK")
        s = socket.socket(family, socket.SOCK_RAW, netlink_family)
        try:
            s.setsockopt(_SOL_NETLINK, _NETLINK_CAP_ACK, 1)
            s.setsockopt(_SOL_NETLINK, _NETLINK_EXT_ACK, 1)
        except OSError:
            s.close()
            raise
        return s

    def set_groups(self, mask: int):
        """Call ``setsockopt(SOL_SOCKET, 1, mask)`` on the socket.

        NOTE(review): option 1 at the SOL_SOCKET level appears to be SO_DEBUG,
        not a netlink group mask.  Legacy group masks are normally passed as
        ``nl_groups`` in the SockaddrNl given to bind().  Behaviour is kept
        unchanged pending confirmation; join_group() is the explicit
        per-group alternative.
        """
        self.sock_fd.setsockopt(socket.SOL_SOCKET, 1, mask)

    def join_group(self, group_id: int):
        """Subscribe the socket to multicast group ``group_id``."""
        self.sock_fd.setsockopt(_SOL_NETLINK, _NETLINK_ADD_MEMBERSHIP, group_id)

    def write_message(self, msg, verbose=True):
        """Serialize ``msg`` and write it to the socket, best effort.

        Write errors and short writes are printed rather than raised.  The
        short-write check is explicit, not an ``assert``, so it still runs
        under ``python -O``.
        """
        if verbose:
            print("vvvvvvvv OUT vvvvvvvv")
            msg.print_message()
        msg_bytes = bytes(msg)
        try:
            ret = os.write(self.sock_fd.fileno(), msg_bytes)
        except OSError as e:
            print("write({}) -> {}".format(len(msg_bytes), e))
            return
        if ret != len(msg_bytes):
            print("write({}) -> short write of {} bytes".format(len(msg_bytes), ret))

    def parse_message(self, data: bytes):
        """Parse one raw netlink message into its registered handler class.

        Types below 16 are core messages (DONE/ERROR/...).  On a NETLINK_ROUTE
        socket the nlmsg_type selects the class.  Otherwise the message is
        generic netlink: nlmsg_type is a family id and the genl cmd selects
        the class.  Unknown messages fall back to BaseNetlinkMessage.

        Raises:
            ValueError: ``data`` is shorter than the required headers.
        """
        if len(data) < sizeof(Nlmsghdr):
            raise ValueError("Short read from nl: {} bytes".format(len(data)))
        hdr = Nlmsghdr.from_buffer_copy(data)
        if hdr.nlmsg_type < 16:
            family_name = "netlink_core"
            nlmsg_type = hdr.nlmsg_type
        elif self._sock_family == NlConst.NETLINK_ROUTE:
            family_name = "netlink_route"
            nlmsg_type = hdr.nlmsg_type
        else:
            # Generic netlink
            if len(data) < sizeof(Nlmsghdr) + sizeof(GenlMsgHdr):
                raise ValueError("Short read from genl: {} bytes".format(len(data)))
            family_name = self._family_map.get(hdr.nlmsg_type, "")
            ghdr = GenlMsgHdr.from_buffer_copy(data[sizeof(Nlmsghdr):])
            nlmsg_type = ghdr.cmd
        cls = self.msgmap.get(family_name, {}).get(nlmsg_type)
        if not cls:
            cls = BaseNetlinkMessage
        return cls.from_bytes(self.helper, data)

    def get_genl_family_id(self, family_name):
        """Resolve a generic-netlink family name to its id via nlctrl.

        The resolved id is remembered in ``_family_map`` so later messages of
        that family are parsed with the right handler classes.

        Raises:
            ValueError: the kernel returned an error, or the reply lacks
                CTRL_ATTR_FAMILY_ID.
        """
        hdr = Nlmsghdr(
            nlmsg_type=NlConst.GENL_ID_CTRL,
            nlmsg_flags=NlmBaseFlags.NLM_F_REQUEST.value,
            nlmsg_seq=self.helper.get_seq(),
        )
        ghdr = GenlMsgHdr(cmd=GenlCtrlMsgType.CTRL_CMD_GETFAMILY.value)
        nla = NlAttrStr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_NAME, family_name)
        hdr.nlmsg_len = sizeof(Nlmsghdr) + sizeof(GenlMsgHdr) + len(bytes(nla))

        msg_bytes = bytes(hdr) + bytes(ghdr) + bytes(nla)
        self.write_data(msg_bytes)
        while True:
            rx_msg = self.read_message()
            if hdr.nlmsg_seq != rx_msg.nl_hdr.nlmsg_seq:
                continue  # unrelated message (e.g. multicast); keep reading
            if rx_msg.is_type(NlMsgType.NLMSG_ERROR):
                if rx_msg.error_code != 0:
                    raise ValueError("unable to get family {}".format(family_name))
                continue  # plain ACK; the family reply is still pending
            id_nla = rx_msg.get_nla(GenlCtrlAttrType.CTRL_ATTR_FAMILY_ID)
            if not id_nla:
                raise ValueError("unable to get family {}".format(family_name))
            family_id = id_nla.u16
            self._family_map[family_id] = family_name
            return family_id

    def write_data(self, data: bytes):
        self.sock_fd.send(data)

    def read_data(self):
        """Append received data to the buffer until it holds a full nlmsghdr."""
        while True:
            data = self.sock_fd.recv(65535)
            self._data += data
            if len(self._data) >= sizeof(Nlmsghdr):
                break

    def read_message(self):
        """Return the next parsed message, reading from the socket as needed."""
        if len(self._data) < sizeof(Nlmsghdr):
            self.read_data()
        hdr = Nlmsghdr.from_buffer_copy(self._data)
        while hdr.nlmsg_len > len(self._data):
            self.read_data()
        raw_msg = self._data[: hdr.nlmsg_len]
        self._data = self._data[hdr.nlmsg_len:]
        return self.parse_message(raw_msg)

    def get_reply(self, tx_msg):
        """Send ``tx_msg``; return the first received message with its seq."""
        self.write_message(tx_msg)
        while True:
            rx_msg = self.read_message()
            if tx_msg.nl_hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq:
                return rx_msg


class NetlinkMultipartIterator:
    """Iterate the parts of a multipart reply until NLMSG_DONE.

    ``obj`` is anything with ``read_message()`` (Nlsock or a test template).

    Raises:
        ValueError: on a foreign sequence number, an NLMSG_ERROR, a non-zero
            NLMSG_DONE error, or a message not of ``msg_type``.
    """

    def __init__(self, obj, seq_number: int, msg_type):
        self._obj = obj
        self._seq = seq_number
        self._msg_type = msg_type

    def __iter__(self):
        return self

    def __next__(self):
        msg = self._obj.read_message()
        if self._seq != msg.nl_hdr.nlmsg_seq:
            raise ValueError("bad sequence number")
        if msg.is_type(NlMsgType.NLMSG_ERROR):
            raise ValueError(
                "error while handling multipart msg: {}".format(msg.error_code)
            )
        elif msg.is_type(NlMsgType.NLMSG_DONE):
            if msg.error_code == 0:
                raise StopIteration
            raise ValueError(
                "error listing some parts of the multipart msg: {}".format(
                    msg.error_code
                )
            )
        elif not msg.is_type(self._msg_type):
            raise ValueError("bad message type: {}".format(msg))
        return msg


class NetlinkTestTemplate:
    """Mixin for netlink tests: owns an Nlsock and logs TX/RX traffic."""

    REQUIRED_MODULES = ["netlink"]

    def setup_netlink(self, netlink_family: NlConst):
        self.helper = NlHelper()
        self.nlsock = Nlsock(netlink_family, self.helper)

    def write_message(self, msg, silent=False):
        """Send ``msg``, printing it (and its raw bytes) unless ``silent``."""
        if not silent:
            print("")
            print("============= >> TX MESSAGE =============")
            msg.print_message()
            msg.print_as_bytes(bytes(msg), "-- DATA --")
        self.nlsock.write_data(bytes(msg))

    def read_message(self, silent=False):
        """Read the next message, printing it unless ``silent``."""
        msg = self.nlsock.read_message()
        if not silent:
            print("")
            print("============= << RX MESSAGE =============")
            msg.print_message()
        return msg

    def get_reply(self, tx_msg):
        """Send ``tx_msg``; return the first received message with its seq."""
        self.write_message(tx_msg)
        while True:
            rx_msg = self.read_message()
            if tx_msg.nl_hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq:
                return rx_msg

    def read_msg_list(self, seq, msg_type):
        """Collect every part of the multipart reply with sequence ``seq``."""
        return list(NetlinkMultipartIterator(self, seq, msg_type))