xref: /freebsd/tests/atf_python/sys/netlink/message.py (revision 04a036601e10237ae00655e515aeb78762eb5d1a)
1#!/usr/local/bin/python3
2import struct
3from ctypes import sizeof
4from enum import Enum
5from typing import List
6from typing import NamedTuple
7
8from atf_python.sys.netlink.attrs import NlAttr
9from atf_python.sys.netlink.attrs import NlAttrNested
10from atf_python.sys.netlink.base_headers import NlmAckFlags
11from atf_python.sys.netlink.base_headers import NlmNewFlags
12from atf_python.sys.netlink.base_headers import NlmGetFlags
13from atf_python.sys.netlink.base_headers import NlmDeleteFlags
14from atf_python.sys.netlink.base_headers import NlmBaseFlags
15from atf_python.sys.netlink.base_headers import Nlmsghdr
16from atf_python.sys.netlink.base_headers import NlMsgType
17from atf_python.sys.netlink.utils import align4
18from atf_python.sys.netlink.utils import enum_or_int
19from atf_python.sys.netlink.utils import get_bitmask_str
20
21
class NlMsgCategory(Enum):
    """Coarse classification of a netlink message, used to pick which
    category-specific NLM_F_* flag set applies when rendering flags."""

    # Category could not be determined; only the base flags are decoded.
    UNKNOWN = 0
    # Retrieval request (decoded with the GET-specific flags).
    GET = 1
    # Creation request (decoded with the NEW-specific flags).
    NEW = 2
    # Removal request (decoded with the DELETE-specific flags).
    DELETE = 3
    # NLMSG_ERROR reply (decoded with the ACK-specific flags).
    ACK = 4
28
29
30class NlMsgProps(NamedTuple):
31    msg: Enum
32    category: NlMsgCategory
33
34
class BaseNetlinkMessage(object):
    """Generic netlink message: an Nlmsghdr plus a list of attributes.

    ``helper`` must provide ``get_seq()`` and a ``pid`` attribute; they
    seed the sequence number and port id of newly created message headers.
    """

    def __init__(self, helper, nlmsg_type):
        self.nlmsg_type = enum_or_int(nlmsg_type)
        self.nla_list = []
        # Raw bytes this message was parsed from (set by from_bytes()).
        self._orig_data = None
        self.helper = helper
        self.nl_hdr = Nlmsghdr(
            nlmsg_type=self.nlmsg_type, nlmsg_seq=helper.get_seq(), nlmsg_pid=helper.pid
        )
        # Family-specific header following the nlmsghdr; unused at this level.
        self.base_hdr = None

    def set_request(self, need_ack=True):
        """Mark the message as a request (NLM_F_REQUEST), optionally with NLM_F_ACK."""
        flags = [NlmBaseFlags.NLM_F_REQUEST]
        if need_ack:
            flags.append(NlmBaseFlags.NLM_F_ACK)
        self.add_nlflags(flags)

    def add_nlflags(self, flags: List):
        """OR the given flags (enum members or plain ints) into nlmsg_flags."""
        int_flags = 0
        for flag in flags:
            int_flags |= enum_or_int(flag)
        self.nl_hdr.nlmsg_flags |= int_flags

    def add_nla(self, nla):
        """Append an attribute to this message's attribute list."""
        self.nla_list.append(nla)

    def _get_nla(self, nla_list, nla_type):
        """Return the first attribute in nla_list with the given type, or None."""
        nla_type_raw = enum_or_int(nla_type)
        return next((nla for nla in nla_list if nla.nla_type == nla_type_raw), None)

    def get_nla(self, nla_type):
        """Return the first top-level attribute of nla_type, or None."""
        return self._get_nla(self.nla_list, nla_type)

    @staticmethod
    def parse_nl_header(data: bytes):
        """Decode the leading Nlmsghdr from data.

        :returns: tuple (header, size of the header in bytes)
        :raises ValueError: if data is shorter than the header.
        """
        if len(data) < sizeof(Nlmsghdr):
            raise ValueError("length less than netlink message header")
        return Nlmsghdr.from_buffer_copy(data), sizeof(Nlmsghdr)

    def is_type(self, nlmsg_type):
        """Return True if the header's nlmsg_type equals nlmsg_type (enum or int)."""
        return enum_or_int(nlmsg_type) == self.nl_hdr.nlmsg_type

    def is_reply(self, hdr):
        """Return True if hdr is an NLMSG_ERROR (error/ACK) message."""
        return hdr.nlmsg_type == NlMsgType.NLMSG_ERROR.value

    def _get_msg_type(self):
        """Return the raw numeric message type from the header.

        Defined here so that msg_name also works on plain
        BaseNetlinkMessage instances, not only on subclasses.
        """
        return self.nl_hdr.nlmsg_type

    @property
    def msg_name(self):
        """Printable message-type name; the raw number at this level."""
        return "msg#{}".format(self._get_msg_type())

    def _get_nl_category(self):
        """Classify the message; only NLMSG_ERROR replies are recognized here."""
        if self.is_reply(self.nl_hdr):
            return NlMsgCategory.ACK
        return NlMsgCategory.UNKNOWN

    def get_nlm_flags_str(self):
        """Render nlmsg_flags as a string of flag names.

        The base NLM_F_* flags are always decoded; for a known category the
        category-specific flag set is decoded as well.
        """
        flags = self.nl_hdr.nlmsg_flags
        category_flags = {
            NlMsgCategory.GET: NlmGetFlags,
            NlMsgCategory.NEW: NlmNewFlags,
            NlMsgCategory.DELETE: NlmDeleteFlags,
            NlMsgCategory.ACK: NlmAckFlags,
        }
        flags_enum = category_flags.get(self._get_nl_category())
        if flags_enum is None:
            # Unknown category: use the same module-level helper as below,
            # limited to the generic base flags.
            return get_bitmask_str([NlmBaseFlags], flags)
        return get_bitmask_str([NlmBaseFlags, flags_enum], flags)

    def print_nl_header(self, prepend=""):
        """Print a one-line summary of the netlink header."""
        # len=44, type=RTM_DELROUTE, flags=NLM_F_REQUEST|NLM_F_ACK, seq=1641163704, pid=0  # noqa: E501
        hdr = self.nl_hdr
        print(
            "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format(
                prepend,
                hdr.nlmsg_len,
                self.msg_name,
                self.get_nlm_flags_str(),
                hdr.nlmsg_flags,
                hdr.nlmsg_seq,
                hdr.nlmsg_pid,
            )
        )

    @classmethod
    def from_bytes(cls, helper, data):
        """Build a message from raw bytes, decoding only the Nlmsghdr.

        :raises ValueError: if the header cannot be parsed; the raw bytes
            are dumped to stdout first.
        """
        try:
            hdr, _ = BaseNetlinkMessage.parse_nl_header(data)
        except ValueError as e:
            print("Failed to parse nl header: {}".format(e))
            cls.print_as_bytes(data, "msg dump")
            raise
        msg = cls(helper, hdr.nlmsg_type)
        msg._orig_data = data
        # Replace the freshly built header with the one parsed from the wire.
        msg.nl_hdr = hdr
        return msg

    def print_message(self):
        """Print the message (header only at this level)."""
        self.print_nl_header()

    @staticmethod
    def print_as_bytes(data: bytes, descr: str = "msg dump"):
        """Hex-dump data to stdout, 16 bytes per line, framed by descr.

        descr defaults to "msg dump" so single-argument callers work.
        """
        print("===vv {} (len:{:3d}) vv===".format(descr, len(data)))
        step = 16
        for off in range(0, len(data), step):
            print("".join(" {:02X}".format(b) for b in data[off:off + step]))
        print("--------------------")
151
152
class StdNetlinkMessage(BaseNetlinkMessage):
    """Netlink message laid out as: nlmsghdr, base header, attribute list.

    Subclasses supply the family-specific pieces used here:
    ``parse_base_header(data)`` returning ``(header, header_len)``,
    ``messages`` (NlMsgProps entries used to name message types) and
    ``nl_attrs_map`` (attribute descriptors consumed by ``parse_attrs``).
    """

    # Attribute type -> descriptor dict. Each descriptor has an "ad" entry
    # (providing ``.cls`` and ``.val``) and, for nested attributes, a
    # "child" entry holding the nested attribute map.
    nl_attrs_map = {}
    # Known message types; empty here so msg_props returns None instead of
    # raising AttributeError when a subclass does not define it.
    messages = []

    @classmethod
    def from_bytes(cls, helper, data):
        """Parse a complete message: netlink header, base header, attributes.

        :raises ValueError: if any part is malformed or bytes remain after
            the attribute list; the raw bytes are dumped to stdout first.
        """
        try:
            hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data)
        except ValueError as e:
            print("Failed to parse nl header: {}".format(e))
            cls.print_as_bytes(data, "msg dump")
            raise
        msg = cls(helper, hdr.nlmsg_type)
        msg._orig_data = data
        msg.nl_hdr = hdr

        offset = align4(hdrlen)
        try:
            base_hdr, hdrlen = msg.parse_base_header(data[offset:])
        except ValueError as e:
            print("Failed to parse nl rt header: {}".format(e))
            cls.print_as_bytes(data, "msg dump")
            raise
        msg.base_hdr = base_hdr
        offset += align4(hdrlen)
        # XXX: CAP_ACK

        nla_offset = offset
        try:
            nla_list, nla_len = msg.parse_nla_list(data[offset:])
            offset += nla_len
            if offset != len(data):
                raise ValueError(
                    "{} bytes left at the end of the packet".format(len(data) - offset)
                )  # noqa: E501
        except ValueError as e:
            print(
                "Failed to parse nla attributes at offset {}: {}".format(nla_offset, e)
            )  # noqa: E501
            cls.print_as_bytes(data, "msg dump")
            cls.print_as_bytes(data[nla_offset:], "failed block")
            raise
        msg.nla_list = nla_list
        return msg

    def parse_attrs(self, data: bytes, attr_map):
        """Decode a sequence of TLV attributes from data.

        :param data: raw attribute bytes (no netlink or base header).
        :param attr_map: attribute type -> descriptor (see nl_attrs_map);
            types not in the map become plain NlAttr objects.
        :returns: tuple (list of attributes, bytes consumed including
            per-attribute 4-byte alignment padding).
        :raises ValueError: if an attribute length is shorter than its
            4-byte header or runs past the end of data.
        """
        ret = []
        off = 0
        while len(data) - off >= 4:
            nla_len, raw_nla_type = struct.unpack("@HH", data[off:off + 4])
            if nla_len < 4:
                # The length includes the 4-byte header; a zero length would
                # also never advance ``off`` (align4(0) == 0) and loop forever.
                raise ValueError(
                    "attr length {} is less than the attr header length 4".format(
                        nla_len
                    )
                )
            if nla_len + off > len(data):
                raise ValueError(
                    "attr length {} > than the remaining length {}".format(
                        nla_len, len(data) - off
                    )
                )
            # Strip the two high flag bits (NLA_F_NESTED / NLA_F_NET_BYTEORDER).
            nla_type = raw_nla_type & 0x3FFF
            if nla_type not in attr_map:
                # unknown attribute: keep the raw type and payload
                val = NlAttr(raw_nla_type, data[off + 4:off + nla_len])
            else:
                v = attr_map[nla_type]
                if "child" in v:
                    # nested: recurse into the payload with the child map
                    attrs, _ = self.parse_attrs(
                        data[off + 4:off + nla_len], v["child"]
                    )
                    val = NlAttrNested(v["ad"].val, attrs)
                else:
                    val = v["ad"].cls.from_bytes(data[off:off + nla_len], v["ad"].val)
            ret.append(val)
            off += align4(nla_len)
        return ret, off

    def parse_nla_list(self, data: bytes) -> "tuple[List[NlAttr], int]":
        """Parse the top-level attribute list using nl_attrs_map.

        :returns: tuple (attributes, bytes consumed), as parse_attrs.
        """
        return self.parse_attrs(data, self.nl_attrs_map)

    def __bytes__(self):
        """Serialize the message, updating nl_hdr.nlmsg_len to the total size."""
        attrs = b"".join(bytes(nla) for nla in self.nla_list)
        payload = bytes(self.base_hdr) + attrs
        self.nl_hdr.nlmsg_len = len(payload) + sizeof(Nlmsghdr)
        return bytes(self.nl_hdr) + payload

    def _get_msg_type(self):
        """Return the raw numeric message type from the header."""
        return self.nl_hdr.nlmsg_type

    @property
    def msg_props(self):
        """Return the NlMsgProps entry matching this message's type, or None."""
        msg_type = self._get_msg_type()
        return next(
            (props for props in self.messages if props.msg.value == msg_type), None
        )

    @property
    def msg_name(self):
        """Symbolic message-type name when known, else the generic form."""
        msg_props = self.msg_props
        if msg_props is not None:
            return msg_props.msg.name
        return super().msg_name

    def print_base_header(self, hdr, prepend=""):
        """Print the family-specific base header; no-op unless overridden."""
        pass

    def print_message(self):
        """Print the netlink header, base header and every attribute."""
        self.print_nl_header()
        self.print_base_header(self.base_hdr, " ")
        for nla in self.nla_list:
            nla.print_attr("  ")
262