#!/usr/local/bin/python3
import struct
from ctypes import sizeof
from enum import Enum
from typing import List
from typing import NamedTuple
from typing import Tuple

from atf_python.sys.netlink.attrs import NlAttr
from atf_python.sys.netlink.attrs import NlAttrNested
from atf_python.sys.netlink.base_headers import NlmAckFlags
from atf_python.sys.netlink.base_headers import NlmNewFlags
from atf_python.sys.netlink.base_headers import NlmGetFlags
from atf_python.sys.netlink.base_headers import NlmDeleteFlags
from atf_python.sys.netlink.base_headers import NlmBaseFlags
from atf_python.sys.netlink.base_headers import Nlmsghdr
from atf_python.sys.netlink.base_headers import NlMsgType
from atf_python.sys.netlink.utils import align4
from atf_python.sys.netlink.utils import enum_or_int
from atf_python.sys.netlink.utils import get_bitmask_str

# Size of an attribute header: two native-endian 16-bit fields (len, type).
_NLA_HDRLEN = 4
# Clears the two high bits of nla_type (the NLA_F_NESTED /
# NLA_F_NET_BYTEORDER flag bits in netlink.h), leaving the attribute id.
_NLA_TYPE_MASK = 0x3FFF


class NlMsgCategory(Enum):
    """Coarse message category; selects which NLM_F_* flag set is printed."""

    UNKNOWN = 0
    GET = 1
    NEW = 2
    DELETE = 3
    ACK = 4


class NlMsgProps(NamedTuple):
    """Properties of one known message type of a netlink family."""

    msg: Enum  # enum member whose .value is matched against the message type
    category: NlMsgCategory


class BaseNetlinkMessage(object):
    """Generic netlink message: an ``Nlmsghdr`` plus a list of attributes.

    ``helper`` must provide ``get_seq()`` and ``pid`` (used to pre-fill the
    header) and, for flag formatting of UNKNOWN-category messages,
    ``get_bitmask_str()``.
    """

    def __init__(self, helper, nlmsg_type):
        self.nlmsg_type = enum_or_int(nlmsg_type)
        self.nla_list = []
        self._orig_data = None  # raw bytes when built via from_bytes()
        self.helper = helper
        self.nl_hdr = Nlmsghdr(
            nlmsg_type=self.nlmsg_type, nlmsg_seq=helper.get_seq(), nlmsg_pid=helper.pid
        )
        self.base_hdr = None

    def set_request(self, need_ack=True):
        """Mark the message as a request, optionally also requesting an ACK."""
        self.add_nlflags([NlmBaseFlags.NLM_F_REQUEST])
        if need_ack:
            self.add_nlflags([NlmBaseFlags.NLM_F_ACK])

    def add_nlflags(self, flags: List):
        """OR *flags* (enum members or ints) into the header's nlmsg_flags."""
        int_flags = 0
        for flag in flags:
            int_flags |= enum_or_int(flag)
        self.nl_hdr.nlmsg_flags |= int_flags

    def add_nla(self, nla):
        """Append attribute *nla* to the message's attribute list."""
        self.nla_list.append(nla)

    def _get_nla(self, nla_list, nla_type):
        """Return the first attribute in *nla_list* of *nla_type*, or None."""
        nla_type_raw = enum_or_int(nla_type)
        for nla in nla_list:
            if nla.nla_type == nla_type_raw:
                return nla
        return None

    def get_nla(self, nla_type):
        """Return the first top-level attribute of *nla_type*, or None."""
        return self._get_nla(self.nla_list, nla_type)

    @staticmethod
    def parse_nl_header(data: bytes):
        """Decode the Nlmsghdr at the start of *data*.

        Returns:
            (header, header_size) tuple.
        Raises:
            ValueError: if *data* is shorter than a netlink header.
        """
        if len(data) < sizeof(Nlmsghdr):
            raise ValueError("length less than netlink message header")
        return Nlmsghdr.from_buffer_copy(data), sizeof(Nlmsghdr)

    def is_type(self, nlmsg_type):
        """Return True if the header's nlmsg_type equals *nlmsg_type*."""
        nlmsg_type_raw = enum_or_int(nlmsg_type)
        return nlmsg_type_raw == self.nl_hdr.nlmsg_type

    def is_reply(self, hdr):
        """Return True if *hdr* is an NLMSG_ERROR (ACK / error) message."""
        return hdr.nlmsg_type == NlMsgType.NLMSG_ERROR.value

    def _get_msg_type(self):
        """Return the value used to identify the message type.

        Defined here (not only in subclasses) so that ``msg_name`` works for
        plain ``BaseNetlinkMessage`` instances too.
        """
        return self.nl_hdr.nlmsg_type

    @property
    def msg_name(self):
        """Fallback display name: ``msg#<type>``."""
        return "msg#{}".format(self._get_msg_type())

    def _get_nl_category(self):
        """Base implementation only distinguishes ACK replies from the rest."""
        if self.is_reply(self.nl_hdr):
            return NlMsgCategory.ACK
        return NlMsgCategory.UNKNOWN

    def get_nlm_flags_str(self):
        """Render nlmsg_flags using the flag names valid for this category."""
        category = self._get_nl_category()
        flags = self.nl_hdr.nlmsg_flags

        if category == NlMsgCategory.UNKNOWN:
            return self.helper.get_bitmask_str(NlmBaseFlags, flags)
        elif category == NlMsgCategory.GET:
            flags_enum = NlmGetFlags
        elif category == NlMsgCategory.NEW:
            flags_enum = NlmNewFlags
        elif category == NlMsgCategory.DELETE:
            flags_enum = NlmDeleteFlags
        elif category == NlMsgCategory.ACK:
            flags_enum = NlmAckFlags
        return get_bitmask_str([NlmBaseFlags, flags_enum], flags)

    def print_nl_header(self, prepend=""):
        """Print a one-line summary of the netlink header."""
        # len=44, type=RTM_DELROUTE, flags=NLM_F_REQUEST|NLM_F_ACK, seq=1641163704, pid=0 # noqa: E501
        hdr = self.nl_hdr
        print(
            "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format(
                prepend,
                hdr.nlmsg_len,
                self.msg_name,
                self.get_nlm_flags_str(),
                hdr.nlmsg_flags,
                hdr.nlmsg_seq,
                hdr.nlmsg_pid,
            )
        )

    @classmethod
    def from_bytes(cls, helper, data):
        """Build a message from raw bytes, decoding only the netlink header.

        Raises:
            ValueError: if the header cannot be parsed (data is dumped first).
        """
        try:
            hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data)
            self = cls(helper, hdr.nlmsg_type)
            self._orig_data = data
            self.nl_hdr = hdr
        except ValueError as e:
            print("Failed to parse nl header: {}".format(e))
            # print_as_bytes() requires a description; omitting it used to
            # turn the intended ValueError into a TypeError.
            cls.print_as_bytes(data, "msg dump")
            raise
        return self

    def print_message(self):
        """Print the message (base class: header only)."""
        self.print_nl_header()

    @staticmethod
    def print_as_bytes(data: bytes, descr: str):
        """Hex-dump *data*, 16 bytes per row, framed by a *descr* banner."""
        print("===vv {} (len:{:3d}) vv===".format(descr, len(data)))
        step = 16
        for off in range(0, len(data), step):
            print("".join(" {:02X}".format(b) for b in data[off : off + step]))
        print("--------------------")


class StdNetlinkMessage(BaseNetlinkMessage):
    """Message with a family-specific base header followed by attributes.

    Subclasses are expected to provide ``parse_base_header()`` and a
    ``messages`` sequence of ``NlMsgProps``. ``nl_attrs_map`` maps a raw
    attribute type to a dict with an ``"ad"`` descriptor (exposing ``.cls``
    and ``.val``) and, for nested attributes, ``"child"`` (the nested map)
    plus an optional ``"is_array"`` flag.
    """

    nl_attrs_map = {}

    @classmethod
    def from_bytes(cls, helper, data):
        """Parse netlink header, base header and attribute list from *data*.

        Raises:
            ValueError: on any parse failure, after dumping the offending data.
        """
        # Header parsing (and its error reporting) is shared with the base.
        self = super().from_bytes(helper, data)

        offset = align4(sizeof(Nlmsghdr))
        try:
            base_hdr, hdrlen = self.parse_base_header(data[offset:])
            self.base_hdr = base_hdr
            offset += align4(hdrlen)
            # XXX: CAP_ACK
        except ValueError as e:
            print("Failed to parse nl rt header: {}".format(e))
            cls.print_as_bytes(data, "msg dump")
            raise

        orig_offset = offset
        try:
            nla_list, nla_len = self.parse_nla_list(data[offset:])
            offset += nla_len
            if offset != len(data):
                raise ValueError(
                    "{} bytes left at the end of the packet".format(len(data) - offset)
                )  # noqa: E501
            self.nla_list = nla_list
        except ValueError as e:
            print(
                "Failed to parse nla attributes at offset {}: {}".format(orig_offset, e)
            )  # noqa: E501
            cls.print_as_bytes(data, "msg dump")
            cls.print_as_bytes(data[orig_offset:], "failed block")
            raise
        return self

    @staticmethod
    def _read_nla_header(data: bytes, off: int) -> Tuple[int, int]:
        """Read and validate the attribute header at ``data[off:]``.

        The caller guarantees at least ``_NLA_HDRLEN`` bytes remain.

        Returns:
            (nla_len, raw_nla_type)
        Raises:
            ValueError: if nla_len is smaller than the header itself (which
                would otherwise stall the parse loop) or overruns *data*.
        """
        nla_len, raw_nla_type = struct.unpack_from("@HH", data, off)
        if nla_len < _NLA_HDRLEN:
            raise ValueError(
                "attr length {} < attr header size {}".format(nla_len, _NLA_HDRLEN)
            )
        if nla_len + off > len(data):
            raise ValueError(
                "attr length {} > than the remaining length {}".format(
                    nla_len, len(data) - off
                )
            )
        return nla_len, raw_nla_type

    def parse_child(self, data: bytes, attr_key, attr_map):
        """Parse *data* as a nested attribute set keyed by *attr_key*."""
        attrs, _ = self.parse_attrs(data, attr_map)
        return NlAttrNested(attr_key, attrs)

    def parse_child_array(self, data: bytes, attr_key, attr_map):
        """Parse *data* as an array of nested attribute sets.

        Each element's (masked) nla_type becomes the key of its child.
        """
        ret = []
        off = 0
        while len(data) - off >= _NLA_HDRLEN:
            nla_len, raw_nla_type = self._read_nla_header(data, off)
            nla_type = raw_nla_type & _NLA_TYPE_MASK
            payload = data[off + _NLA_HDRLEN : off + nla_len]
            ret.append(self.parse_child(payload, nla_type, attr_map))
            off += align4(nla_len)
        return NlAttrNested(attr_key, ret)

    def parse_attrs(self, data: bytes, attr_map):
        """Parse a sequence of attributes using *attr_map*.

        Unknown attribute types are kept as raw ``NlAttr`` objects.

        Returns:
            (list_of_attrs, consumed_length_including_padding)
        """
        ret = []
        off = 0
        while len(data) - off >= _NLA_HDRLEN:
            nla_len, raw_nla_type = self._read_nla_header(data, off)
            nla_type = raw_nla_type & _NLA_TYPE_MASK
            if nla_type in attr_map:
                v = attr_map[nla_type]
                val = v["ad"].cls.from_bytes(data[off : off + nla_len], v["ad"].val)
                if "child" in v:
                    # nested
                    child_data = data[off + _NLA_HDRLEN : off + nla_len]
                    if v.get("is_array", False):
                        # Array of nested attributes
                        val = self.parse_child_array(
                            child_data, v["ad"].val, v["child"]
                        )
                    else:
                        val = self.parse_child(child_data, v["ad"].val, v["child"])
            else:
                # unknown attribute
                val = NlAttr(raw_nla_type, data[off + _NLA_HDRLEN : off + nla_len])
            ret.append(val)
            off += align4(nla_len)
        return ret, off

    def parse_nla_list(self, data: bytes) -> Tuple[List[NlAttr], int]:
        """Parse top-level attributes; returns (attrs, consumed_length)."""
        return self.parse_attrs(data, self.nl_attrs_map)

    def __bytes__(self):
        """Serialize; updates nl_hdr.nlmsg_len to the total length as a side effect."""
        payload = bytes(self.base_hdr) + b"".join(bytes(nla) for nla in self.nla_list)
        self.nl_hdr.nlmsg_len = len(payload) + sizeof(Nlmsghdr)
        return bytes(self.nl_hdr) + payload

    @property
    def msg_props(self):
        """Return the NlMsgProps matching this message's type, or None."""
        msg_type = self._get_msg_type()
        # Subclasses are expected to define ``messages``; tolerate its absence
        # instead of raising AttributeError from a property.
        for msg_props in getattr(self, "messages", ()):
            if msg_props.msg.value == msg_type:
                return msg_props
        return None

    @property
    def msg_name(self):
        """Symbolic message name if known, else the base ``msg#N`` form."""
        msg_props = self.msg_props
        if msg_props is not None:
            return msg_props.msg.name
        return super().msg_name

    def print_base_header(self, hdr, prepend=""):
        """Hook for subclasses to print their base header; no-op here."""
        pass

    def print_message(self):
        """Print the netlink header, base header and every attribute."""
        self.print_nl_header()
        self.print_base_header(self.base_hdr, "  ")
        for nla in self.nla_list:
            nla.print_attr("  ")