#!/usr/local/bin/python3
import struct
from ctypes import sizeof
from enum import Enum
from typing import List
from typing import NamedTuple
from typing import Tuple

from atf_python.sys.netlink.attrs import NlAttr
from atf_python.sys.netlink.attrs import NlAttrNested
from atf_python.sys.netlink.base_headers import NlmAckFlags
from atf_python.sys.netlink.base_headers import NlmNewFlags
from atf_python.sys.netlink.base_headers import NlmGetFlags
from atf_python.sys.netlink.base_headers import NlmDeleteFlags
from atf_python.sys.netlink.base_headers import NlmBaseFlags
from atf_python.sys.netlink.base_headers import Nlmsghdr
from atf_python.sys.netlink.base_headers import NlMsgType
from atf_python.sys.netlink.utils import align4
from atf_python.sys.netlink.utils import enum_or_int
from atf_python.sys.netlink.utils import get_bitmask_str


class NlMsgCategory(Enum):
    """Coarse message classification used to choose the nlmsg_flags enum."""

    UNKNOWN = 0
    GET = 1
    NEW = 2
    DELETE = 3
    ACK = 4


class NlMsgProps(NamedTuple):
    """Associates a message-type enum member with its category."""

    msg: Enum
    category: NlMsgCategory


# Category -> enum describing the category-specific nlmsg_flags bits.
# NlMsgCategory.UNKNOWN is intentionally absent: such messages are rendered
# with the base flags only (see get_nlm_flags_str).
_CATEGORY_FLAGS = {
    NlMsgCategory.GET: NlmGetFlags,
    NlMsgCategory.NEW: NlmNewFlags,
    NlMsgCategory.DELETE: NlmDeleteFlags,
    NlMsgCategory.ACK: NlmAckFlags,
}


class BaseNetlinkMessage(object):
    """Netlink message made of an ``nlmsghdr`` plus a list of attributes.

    ``helper`` must provide ``get_seq()``, a ``pid`` attribute and
    ``get_bitmask_str(cls, val)``; all three are used by this class.
    """

    def __init__(self, helper, nlmsg_type):
        self.nlmsg_type = enum_or_int(nlmsg_type)
        self.nla_list = []
        self._orig_data = None
        self.helper = helper
        self.nl_hdr = Nlmsghdr(
            nlmsg_type=self.nlmsg_type, nlmsg_seq=helper.get_seq(), nlmsg_pid=helper.pid
        )
        self.base_hdr = None

    def set_request(self, need_ack=True):
        """Mark the message as a request, optionally also requesting an ACK."""
        self.add_nlflags([NlmBaseFlags.NLM_F_REQUEST])
        if need_ack:
            self.add_nlflags([NlmBaseFlags.NLM_F_ACK])

    def add_nlflags(self, flags: List):
        """OR every flag in *flags* (enum members or ints) into nlmsg_flags."""
        int_flags = 0
        for flag in flags:
            int_flags |= enum_or_int(flag)
        self.nl_hdr.nlmsg_flags |= int_flags

    def add_nla(self, nla):
        """Append attribute *nla* to this message's attribute list."""
        self.nla_list.append(nla)

    def _get_nla(self, nla_list, nla_type):
        """Return the first attribute in *nla_list* of type *nla_type*, or None."""
        nla_type_raw = enum_or_int(nla_type)
        return next((nla for nla in nla_list if nla.nla_type == nla_type_raw), None)

    def get_nla(self, nla_type):
        """Return the first top-level attribute of type *nla_type*, or None."""
        return self._get_nla(self.nla_list, nla_type)

    @staticmethod
    def parse_nl_header(data: bytes):
        """Decode the ``nlmsghdr`` at the start of *data*.

        Returns ``(header, header_size)``.
        Raises ValueError if *data* is shorter than the header.
        """
        if len(data) < sizeof(Nlmsghdr):
            raise ValueError("length less than netlink message header")
        return Nlmsghdr.from_buffer_copy(data), sizeof(Nlmsghdr)

    def is_type(self, nlmsg_type):
        """Return True if this message's header type equals *nlmsg_type*."""
        nlmsg_type_raw = enum_or_int(nlmsg_type)
        return nlmsg_type_raw == self.nl_hdr.nlmsg_type

    def is_reply(self, hdr):
        """Return True if *hdr* is an NLMSG_ERROR (ack/error) header."""
        return hdr.nlmsg_type == NlMsgType.NLMSG_ERROR.value

    def _get_msg_type(self):
        """Return the raw numeric message type from the netlink header."""
        return self.nl_hdr.nlmsg_type

    @property
    def msg_name(self):
        """Fallback printable name; subclasses may resolve a symbolic name."""
        return "msg#{}".format(self._get_msg_type())

    def _get_nl_category(self):
        """Classify the message; the base class only recognizes replies."""
        if self.is_reply(self.nl_hdr):
            return NlMsgCategory.ACK
        return NlMsgCategory.UNKNOWN

    def get_nlm_flags_str(self):
        """Render nlmsg_flags symbolically using the category's flag enum."""
        flags = self.nl_hdr.nlmsg_flags
        flags_enum = _CATEGORY_FLAGS.get(self._get_nl_category())
        if flags_enum is None:
            # Category unknown: only the generic base flags can be named.
            return self.helper.get_bitmask_str(NlmBaseFlags, flags)
        return get_bitmask_str([NlmBaseFlags, flags_enum], flags)

    def print_nl_header(self, prepend=""):
        """Print a one-line summary of the netlink header, prefixed by *prepend*."""
        # len=44, type=RTM_DELROUTE, flags=NLM_F_REQUEST|NLM_F_ACK, seq=1641163704, pid=0 # noqa: E501
        hdr = self.nl_hdr
        print(
            "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format(
                prepend,
                hdr.nlmsg_len,
                self.msg_name,
                self.get_nlm_flags_str(),
                hdr.nlmsg_flags,
                hdr.nlmsg_seq,
                hdr.nlmsg_pid,
            )
        )

    @classmethod
    def _create_from_nl_header(cls, helper, data):
        """Parse the netlink header of *data* and build a ``cls`` instance.

        Returns ``(message, header_length)``. On a header parse failure the
        raw bytes are dumped and the ValueError is re-raised.
        """
        try:
            hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data)
        except ValueError as e:
            print("Failed to parse nl header: {}".format(e))
            cls.print_as_bytes(data, "msg dump")
            raise
        self = cls(helper, hdr.nlmsg_type)
        self._orig_data = data
        self.nl_hdr = hdr
        return self, hdrlen

    @classmethod
    def from_bytes(cls, helper, data):
        """Build a message from raw bytes, decoding only the netlink header."""
        self, _ = cls._create_from_nl_header(helper, data)
        return self

    def print_message(self):
        """Print the message (the base class only prints its header)."""
        self.print_nl_header()

    @staticmethod
    def print_as_bytes(data: bytes, descr: str = ""):
        """Hex-dump *data*, 16 bytes per line, framed by a *descr* banner.

        *descr* defaults to an empty string so that single-argument callers
        do not fail with TypeError while reporting another error.
        """
        print("===vv {} (len:{:3d}) vv===".format(descr, len(data)))
        step = 16
        for off in range(0, len(data), step):
            print("".join(" {:02X}".format(b) for b in data[off:off + step]))
        print("--------------------")


class StdNetlinkMessage(BaseNetlinkMessage):
    """Netlink message with a family-specific base header and an attribute list.

    Subclasses are expected to provide ``parse_base_header(data)`` returning
    ``(header, header_length)``, a ``messages`` list of NlMsgProps, and
    ``nl_attrs_map`` (attr type -> {"ad": descriptor, optional "child": map}).
    """

    nl_attrs_map = {}

    @classmethod
    def from_bytes(cls, helper, data):
        """Decode the netlink header, base header and attribute list of *data*.

        Raises ValueError (after dumping the offending bytes) if any part is
        malformed or if unparsed bytes remain at the end of the packet.
        """
        self, hdrlen = cls._create_from_nl_header(helper, data)

        offset = align4(hdrlen)
        try:
            base_hdr, hdrlen = self.parse_base_header(data[offset:])
        except ValueError as e:
            print("Failed to parse nl rt header: {}".format(e))
            cls.print_as_bytes(data, "msg dump")
            raise
        self.base_hdr = base_hdr
        offset += align4(hdrlen)
        # XXX: CAP_ACK

        orig_offset = offset
        try:
            nla_list, nla_len = self.parse_nla_list(data[offset:])
            offset += nla_len
            if offset != len(data):
                raise ValueError(
                    "{} bytes left at the end of the packet".format(len(data) - offset)
                )  # noqa: E501
        except ValueError as e:
            print(
                "Failed to parse nla attributes at offset {}: {}".format(orig_offset, e)
            )  # noqa: E501
            cls.print_as_bytes(data, "msg dump")
            cls.print_as_bytes(data[orig_offset:], "failed block")
            raise
        self.nla_list = nla_list
        return self

    def parse_attrs(self, data: bytes, attr_map):
        """Parse a sequence of netlink attributes from *data*.

        Known types (keys of *attr_map*) are decoded by their descriptor;
        entries with a "child" map are parsed recursively as nested
        attributes. Unknown types become raw NlAttr objects.

        Returns ``(attrs, consumed_length)``; consumed_length includes 4-byte
        alignment padding. Raises ValueError on a malformed attribute length.
        """
        ret = []
        off = 0
        while len(data) - off >= 4:
            nla_len, raw_nla_type = struct.unpack("@HH", data[off:off + 4])
            if nla_len < 4:
                # A length smaller than the 4-byte attribute header is
                # malformed; nla_len == 0 would otherwise loop forever.
                raise ValueError(
                    "attr length {} < attribute header size 4".format(nla_len)
                )
            if nla_len + off > len(data):
                raise ValueError(
                    "attr length {} > than the remaining length {}".format(
                        nla_len, len(data) - off
                    )
                )
            # Strip the top two flag bits (NLA_F_NESTED / NLA_F_NET_BYTEORDER).
            nla_type = raw_nla_type & 0x3FFF
            if nla_type not in attr_map:
                # unknown attribute
                val = NlAttr(raw_nla_type, data[off + 4:off + nla_len])
            else:
                v = attr_map[nla_type]
                if "child" in v:
                    # nested
                    attrs, _ = self.parse_attrs(
                        data[off + 4:off + nla_len], v["child"]
                    )
                    val = NlAttrNested(v["ad"].val, attrs)
                else:
                    val = v["ad"].cls.from_bytes(data[off:off + nla_len], v["ad"].val)
            ret.append(val)
            off += align4(nla_len)
        return ret, off

    def parse_nla_list(self, data: bytes) -> Tuple[List[NlAttr], int]:
        """Parse top-level attributes; returns ``(attrs, consumed_length)``."""
        return self.parse_attrs(data, self.nl_attrs_map)

    def __bytes__(self):
        """Serialize the message, updating nlmsg_len to the total size."""
        attrs = b"".join(bytes(nla) for nla in self.nla_list)
        payload = bytes(self.base_hdr) + attrs
        self.nl_hdr.nlmsg_len = len(payload) + sizeof(Nlmsghdr)
        return bytes(self.nl_hdr) + payload

    @property
    def msg_props(self):
        """NlMsgProps matching this message's type, or None if unknown."""
        msg_type = self._get_msg_type()
        # getattr: "messages" is supplied by subclasses; tolerate its absence.
        candidates = getattr(self, "messages", ())
        return next((p for p in candidates if p.msg.value == msg_type), None)

    @property
    def msg_name(self):
        """Symbolic message name if known, else the generic fallback."""
        msg_props = self.msg_props
        if msg_props is not None:
            return msg_props.msg.name
        return super().msg_name

    def print_base_header(self, hdr, prepend=""):
        """Hook for subclasses to print their base header; no-op here."""
        pass

    def print_message(self):
        """Print the netlink header, base header and every attribute."""
        self.print_nl_header()
        self.print_base_header(self.base_hdr, "  ")
        for nla in self.nla_list:
            nla.print_attr("  ")