#!/usr/local/bin/python3
import struct
from ctypes import sizeof
from typing import List
from typing import Tuple

from atf_python.sys.netlink.attrs import NlAttr
from atf_python.sys.netlink.attrs import NlAttrNested
from atf_python.sys.netlink.base_headers import NlmBaseFlags
from atf_python.sys.netlink.base_headers import Nlmsghdr
from atf_python.sys.netlink.base_headers import NlMsgType
from atf_python.sys.netlink.utils import align4
from atf_python.sys.netlink.utils import enum_or_int

# Size of the generic attribute header: u16 nla_len + u16 nla_type.
_NLA_HDRLEN = 4
# The two top bits of nla_type are flags (NLA_F_NESTED, NLA_F_NET_BYTEORDER);
# the remaining 14 bits are the attribute type itself.
_NLA_TYPE_MASK = 0x3FFF


class BaseNetlinkMessage:
    """Netlink message consisting of an Nlmsghdr and a list of attributes.

    ``helper`` must provide ``get_seq()``, ``pid``, ``get_nlmsg_name()`` and
    ``get_nlm_flags_str()``; these are the only members used here.
    """

    def __init__(self, helper, nlmsg_type):
        self.nlmsg_type = enum_or_int(nlmsg_type)
        self.nla_list = []
        self._orig_data = None
        self.helper = helper
        self.nl_hdr = Nlmsghdr(
            nlmsg_type=self.nlmsg_type, nlmsg_seq=helper.get_seq(), nlmsg_pid=helper.pid
        )
        self.base_hdr = None

    def set_request(self, need_ack=True):
        """Mark the message as a request, optionally asking for an ACK."""
        self.add_nlflags([NlmBaseFlags.NLM_F_REQUEST])
        if need_ack:
            self.add_nlflags([NlmBaseFlags.NLM_F_ACK])

    def add_nlflags(self, flags: List):
        """OR every flag (enum member or int) into ``nl_hdr.nlmsg_flags``."""
        int_flags = 0
        for flag in flags:
            int_flags |= enum_or_int(flag)
        self.nl_hdr.nlmsg_flags |= int_flags

    def add_nla(self, nla):
        """Append an attribute to the message."""
        self.nla_list.append(nla)

    def _get_nla(self, nla_list, nla_type):
        """Return the first attribute in ``nla_list`` of ``nla_type``, or None."""
        nla_type_raw = enum_or_int(nla_type)
        return next((nla for nla in nla_list if nla.nla_type == nla_type_raw), None)

    def get_nla(self, nla_type):
        """Return the first top-level attribute of ``nla_type``, or None."""
        return self._get_nla(self.nla_list, nla_type)

    @staticmethod
    def parse_nl_header(data: bytes):
        """Parse the leading Nlmsghdr of ``data``.

        Returns ``(header, sizeof(Nlmsghdr))``.
        Raises ValueError if ``data`` is shorter than the header.
        """
        if len(data) < sizeof(Nlmsghdr):
            raise ValueError("length less than netlink message header")
        return Nlmsghdr.from_buffer_copy(data), sizeof(Nlmsghdr)

    def is_type(self, nlmsg_type):
        """Return True if this message's header has type ``nlmsg_type``."""
        nlmsg_type_raw = enum_or_int(nlmsg_type)
        return nlmsg_type_raw == self.nl_hdr.nlmsg_type

    def is_reply(self, hdr):
        """Return True if ``hdr`` is an NLMSG_ERROR message.

        The result is used to select how header flags are decoded.
        """
        return hdr.nlmsg_type == NlMsgType.NLMSG_ERROR.value

    def print_nl_header(self, hdr, prepend=""):
        """Print ``hdr`` as ``len=.., type=.., flags=NAMES(0xHEX), seq=.., pid=..``."""
        reply = self.is_reply(hdr)
        msg_name = self.helper.get_nlmsg_name(hdr.nlmsg_type)
        flags_str = self.helper.get_nlm_flags_str(msg_name, reply, hdr.nlmsg_flags)
        print(
            "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format(
                prepend,
                hdr.nlmsg_len,
                msg_name,
                flags_str,
                hdr.nlmsg_flags,
                hdr.nlmsg_seq,
                hdr.nlmsg_pid,
            )
        )

    @classmethod
    def _from_nl_header(cls, helper, data):
        """Create a ``cls`` instance from the Nlmsghdr at the start of ``data``.

        Returns ``(instance, header_length)``. On a malformed header, dumps
        ``data`` and re-raises the ValueError.
        """
        try:
            hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data)
        except ValueError as e:
            print("Failed to parse nl header: {}".format(e))
            cls.print_as_bytes(data, "msg dump")
            raise
        msg = cls(helper, hdr.nlmsg_type)
        msg._orig_data = data
        msg.nl_hdr = hdr
        return msg, hdrlen

    @classmethod
    def from_bytes(cls, helper, data):
        """Build a message from raw ``data``; only the Nlmsghdr is parsed."""
        msg, _ = cls._from_nl_header(helper, data)
        return msg

    def print_message(self):
        """Print the netlink header of this message."""
        self.print_nl_header(self.nl_hdr)

    @staticmethod
    def print_as_bytes(data: bytes, descr: str = ""):
        """Hex-dump ``data`` with 16 bytes per line under a ``descr`` banner.

        ``descr`` defaults to "" so that callers passing only ``data`` work.
        """
        print("===vv {} (len:{:3d}) vv===".format(descr, len(data)))
        step = 16
        for off in range(0, len(data), step):
            print("".join(" {:02X}".format(b) for b in data[off:off + step]))
        print("--------------------")


class StdNetlinkMessage(BaseNetlinkMessage):
    """Message laid out as Nlmsghdr + family-specific base header + attributes.

    Subclasses are expected to implement ``parse_base_header(data)``, which
    returns ``(header, length)``. They also set ``nl_attrs_map``, which maps an
    attribute type to a dict with an ``"ad"`` entry (exposing ``.cls`` and
    ``.val``) and an optional ``"child"`` map for nested attributes.
    """

    nl_attrs_map = {}

    @classmethod
    def from_bytes(cls, helper, data):
        """Parse ``data`` into the header, base header and attribute list.

        Raises ValueError, after dumping the offending bytes, if any part is
        malformed or if unparsed bytes remain at the end of ``data``.
        """
        msg, hdrlen = cls._from_nl_header(helper, data)

        offset = align4(hdrlen)
        try:
            base_hdr, hdrlen = msg.parse_base_header(data[offset:])
        except ValueError as e:
            print("Failed to parse nl rt header: {}".format(e))
            cls.print_as_bytes(data, "msg dump")
            raise
        msg.base_hdr = base_hdr
        offset += align4(hdrlen)
        # XXX: CAP_ACK

        orig_offset = offset
        try:
            nla_list, nla_len = msg.parse_nla_list(data[offset:])
            offset += nla_len
            if offset != len(data):
                raise ValueError(
                    "{} bytes left at the end of the packet".format(len(data) - offset)
                )
        except ValueError as e:
            print(
                "Failed to parse nla attributes at offset {}: {}".format(orig_offset, e)
            )
            cls.print_as_bytes(data, "msg dump")
            cls.print_as_bytes(data[orig_offset:], "failed block")
            raise
        msg.nla_list = nla_list
        return msg

    def parse_attrs(self, data: bytes, attr_map) -> Tuple[List[NlAttr], int]:
        """Parse a run of netlink attributes from ``data`` using ``attr_map``.

        Known types are decoded via their ``"ad"`` descriptor. Entries that
        have a ``"child"`` map are parsed recursively into NlAttrNested.
        Unknown types are kept as a raw NlAttr.

        Returns ``(attributes, bytes_consumed)``, where the consumed count
        includes 4-byte alignment padding.
        Raises ValueError on a truncated or undersized attribute.
        """
        ret = []
        off = 0
        while len(data) - off >= _NLA_HDRLEN:
            nla_len, raw_nla_type = struct.unpack(
                "@HH", data[off:off + _NLA_HDRLEN]
            )
            # A length below the header size would never advance `off`
            # (nla_len == 0 loops forever) or would read inside the header.
            if nla_len < _NLA_HDRLEN:
                raise ValueError(
                    "attr length {} < attr header length {}".format(
                        nla_len, _NLA_HDRLEN
                    )
                )
            if nla_len + off > len(data):
                raise ValueError(
                    "attr length {} > than the remaining length {}".format(
                        nla_len, len(data) - off
                    )
                )
            # Strip NLA_F_NESTED / NLA_F_NET_BYTEORDER, keep the 14-bit type.
            nla_type = raw_nla_type & _NLA_TYPE_MASK
            payload = data[off + _NLA_HDRLEN:off + nla_len]
            if nla_type in attr_map:
                v = attr_map[nla_type]
                val = v["ad"].cls.from_bytes(data[off:off + nla_len], v["ad"].val)
                if "child" in v:
                    # nested
                    attrs, _ = self.parse_attrs(payload, v["child"])
                    val = NlAttrNested(v["ad"].val, attrs)
            else:
                # unknown attribute
                val = NlAttr(raw_nla_type, payload)
            ret.append(val)
            off += align4(nla_len)
        return ret, off

    def parse_nla_list(self, data: bytes) -> Tuple[List[NlAttr], int]:
        """Parse the top-level attributes using ``nl_attrs_map``."""
        return self.parse_attrs(data, self.nl_attrs_map)

    def __bytes__(self):
        """Serialize the message, updating ``nl_hdr.nlmsg_len`` to match."""
        payload = bytes(self.base_hdr) + b"".join(bytes(nla) for nla in self.nla_list)
        self.nl_hdr.nlmsg_len = len(payload) + sizeof(Nlmsghdr)
        return bytes(self.nl_hdr) + payload

    def print_base_header(self, hdr, prepend=""):
        """Hook for subclasses to print their base header; no-op here."""
        pass

    def print_message(self):
        """Print the netlink header, the base header and every attribute."""
        self.print_nl_header(self.nl_hdr)
        self.print_base_header(self.base_hdr, " ")
        for nla in self.nla_list:
            nla.print_attr(" ")