xref: /freebsd/tests/atf_python/sys/netlink/attrs.py (revision 7ff314380919d5b7c4b45e68fd093a51a998f845)
1import socket
2import struct
3from enum import Enum
4
5from atf_python.sys.netlink.utils import align4
6from atf_python.sys.netlink.utils import enum_or_int
7
8
class NlAttr(object):
    """Generic netlink attribute holding an opaque bytes payload.

    On the wire an attribute is a ``struct nlattr`` header (native-endian
    u16 ``nla_len`` followed by u16 ``nla_type``) and then the payload,
    padded with zero bytes up to ``align4`` of its length. Subclasses
    override ``_validate``, ``_parse`` and ``__bytes__`` to give the payload
    a typed representation.
    """

    HDR_LEN = 4  # sizeof(struct nlattr)

    def __init__(self, nla_type, data):
        """Create an attribute.

        :param nla_type: numeric attribute type, or an ``Enum`` member whose
            ``value`` is used as the type (the member is kept for printing).
        :param data: raw payload bytes, without header or padding.
        """
        if isinstance(nla_type, Enum):
            self._nla_type = nla_type.value
            self._enum = nla_type
        else:
            self._nla_type = nla_type
            self._enum = None
        self.nla_list = []
        self._data = data

    @property
    def nla_type(self):
        """Attribute type with the top two bits (0xC000) cleared.

        Presumably those are the netlink NLA_F_* flag bits.
        """
        return self._nla_type & 0x3FFF

    @property
    def nla_len(self):
        """Unpadded attribute length: header plus payload."""
        return len(self._data) + self.HDR_LEN

    def add_nla(self, nla):
        """Append a child attribute to ``nla_list``."""
        self.nla_list.append(nla)

    def print_attr(self, prepend=""):
        """Print a one-line ``len=.. type=NAME(num)`` summary plus the value.

        :param prepend: prefix (e.g. indentation) printed before the line.
        """
        if self._enum is not None:
            type_str = self._enum.name
        else:
            type_str = "nla#{}".format(self.nla_type)
        print(
            "{}len={} type={}({}){}".format(
                prepend, self.nla_len, type_str, self.nla_type, self._print_attr_value()
            )
        )

    @staticmethod
    def _validate(data):
        """Check that ``data`` begins with a consistent nlattr header.

        :raises ValueError: if ``data`` is shorter than the header, or the
            header's ``nla_len`` is below the header size or exceeds
            ``len(data)``.
        """
        if len(data) < NlAttr.HDR_LEN:
            raise ValueError("attribute too short")
        nla_len, _ = struct.unpack("@HH", data[: NlAttr.HDR_LEN])
        if nla_len > len(data):
            raise ValueError("attribute length too big")
        if nla_len < NlAttr.HDR_LEN:
            raise ValueError("attribute length too short")

    @classmethod
    def _parse(cls, data):
        """Build an instance from a buffer that passed ``_validate``.

        Only the ``nla_len`` bytes announced by the header form the payload.
        Trailing padding or following attributes in ``data`` are therefore
        not absorbed into this attribute.
        """
        nla_len, nla_type = struct.unpack("@HH", data[: cls.HDR_LEN])
        return cls(nla_type, data[cls.HDR_LEN : nla_len])

    @classmethod
    def from_bytes(cls, data, attr_type_enum=None):
        """Validate and parse one attribute from ``data``.

        :param data: buffer starting with an nlattr header.
        :param attr_type_enum: ``Enum`` member used as the display name. It
            replaces whatever ``__init__`` recorded; ``None`` clears it.
        :raises: whatever ``cls._validate`` raises (``ValueError`` here).
        """
        cls._validate(data)
        attr = cls._parse(data)
        attr._enum = attr_type_enum
        return attr

    def _to_bytes(self, data: bytes):
        """Serialize ``data`` as this attribute's payload.

        The header's ``nla_len`` records the unpadded length. The payload is
        followed by zero bytes up to ``align4(len(data))``.
        """
        padding = bytes(align4(len(data)) - len(data))
        header = struct.pack("@HH", len(data) + self.HDR_LEN, self._nla_type)
        return header + data + padding

    def __bytes__(self):
        return self._to_bytes(self._data)

    def _print_attr_value(self):
        """Render the payload as space-separated ``xNN`` hex bytes."""
        return " " + " ".join(["x{:02X}".format(b) for b in self._data])
77
78
class NlAttrNested(NlAttr):
    """Attribute whose payload is a list of child attributes.

    The children live in ``nla_list``; ``_data`` is always empty and the
    payload is rebuilt from the children whenever it is needed.

    NOTE(review): ``_parse`` is inherited, so ``from_bytes`` passes the raw
    payload bytes as ``val`` and ``nla_list`` ends up holding bytes rather
    than parsed children. Confirm that callers rebuild nested attributes
    themselves.
    """

    def __init__(self, nla_type, val):
        """:param val: list of child attributes, stored as ``nla_list``."""
        super().__init__(nla_type, b"")
        self.nla_list = val

    def _children_bytes(self):
        # Serialized children, concatenated in list order.
        return b"".join(bytes(child) for child in self.nla_list)

    @property
    def nla_len(self):
        """Header length plus the ``align4``-rounded size of all children."""
        return NlAttr.HDR_LEN + align4(len(self._children_bytes()))

    def print_attr(self, prepend=""):
        """Print a ``{ ... }`` block, each child indented two more spaces."""
        if self._enum is None:
            label = "nla#{}".format(self.nla_type)
        else:
            label = self._enum.name
        opening = "{}len={} type={}({}) {{".format(
            prepend, self.nla_len, label, self.nla_type
        )
        print(opening)
        inner = prepend + "  "
        for child in self.nla_list:
            child.print_attr(inner)
        print("{}}}".format(prepend))

    def __bytes__(self):
        return self._to_bytes(self._children_bytes())
104
105
class NlAttrU32(NlAttr):
    """Attribute carrying one native-endian unsigned 32-bit integer."""

    def __init__(self, nla_type, val):
        """:param val: int or ``Enum`` member, normalized by ``enum_or_int``."""
        self.u32 = enum_or_int(val)
        super().__init__(nla_type, b"")

    @property
    def nla_len(self):
        """Fixed unpadded length: 4-byte header + 4-byte value."""
        return 8

    def _print_attr_value(self):
        return " val={}".format(self.u32)

    @staticmethod
    def _validate(data):
        """Check that ``data`` is exactly one 8-byte attribute.

        Explicit exceptions (instead of ``assert``) keep the check active
        under ``python -O``. They also match the base class, which raises
        ``ValueError`` for malformed attributes.

        :raises ValueError: if ``len(data)`` or the header ``nla_len`` is not 8.
        """
        if len(data) != 8:
            raise ValueError(
                "u32 attribute must be 8 bytes, got {}".format(len(data))
            )
        nla_len, _ = struct.unpack("@HH", data[:4])
        if nla_len != 8:
            raise ValueError(
                "u32 attribute has nla_len {}, expected 8".format(nla_len)
            )

    @classmethod
    def _parse(cls, data):
        """Unpack header and value from an 8-byte buffer."""
        _, nla_type, val = struct.unpack("@HHI", data)
        return cls(nla_type, val)

    def __bytes__(self):
        return self._to_bytes(struct.pack("@I", self.u32))
131
132
class NlAttrU16(NlAttr):
    """Attribute carrying one native-endian unsigned 16-bit integer."""

    def __init__(self, nla_type, val):
        """:param val: int or ``Enum`` member, normalized by ``enum_or_int``."""
        self.u16 = enum_or_int(val)
        super().__init__(nla_type, b"")

    @property
    def nla_len(self):
        """Fixed unpadded length: 4-byte header + 2-byte value."""
        return 6

    def _print_attr_value(self):
        return " val={}".format(self.u16)

    @staticmethod
    def _validate(data):
        """Check that ``data`` is exactly one 6-byte attribute.

        Explicit exceptions (instead of ``assert``) keep the check active
        under ``python -O``. They also match the base class, which raises
        ``ValueError`` for malformed attributes.

        :raises ValueError: if ``len(data)`` or the header ``nla_len`` is not 6.
        """
        if len(data) != 6:
            raise ValueError(
                "u16 attribute must be 6 bytes, got {}".format(len(data))
            )
        nla_len, _ = struct.unpack("@HH", data[:4])
        if nla_len != 6:
            raise ValueError(
                "u16 attribute has nla_len {}, expected 6".format(nla_len)
            )

    @classmethod
    def _parse(cls, data):
        """Unpack header and value from a 6-byte buffer."""
        _, nla_type, val = struct.unpack("@HHH", data)
        return cls(nla_type, val)

    def __bytes__(self):
        return self._to_bytes(struct.pack("@H", self.u16))
158
159
class NlAttrU8(NlAttr):
    """Attribute carrying one unsigned 8-bit integer."""

    def __init__(self, nla_type, val):
        """:param val: int or ``Enum`` member, normalized by ``enum_or_int``."""
        self.u8 = enum_or_int(val)
        super().__init__(nla_type, b"")

    @property
    def nla_len(self):
        """Fixed unpadded length: 4-byte header + 1-byte value."""
        return 5

    def _print_attr_value(self):
        return " val={}".format(self.u8)

    @staticmethod
    def _validate(data):
        """Check that ``data`` is exactly one 5-byte attribute.

        Explicit exceptions (instead of ``assert``) keep the check active
        under ``python -O``. They also match the base class, which raises
        ``ValueError`` for malformed attributes.

        :raises ValueError: if ``len(data)`` or the header ``nla_len`` is not 5.
        """
        if len(data) != 5:
            raise ValueError(
                "u8 attribute must be 5 bytes, got {}".format(len(data))
            )
        nla_len, _ = struct.unpack("@HH", data[:4])
        if nla_len != 5:
            raise ValueError(
                "u8 attribute has nla_len {}, expected 5".format(nla_len)
            )

    @classmethod
    def _parse(cls, data):
        """Unpack header and value from a 5-byte buffer."""
        _, nla_type, val = struct.unpack("@HHB", data)
        return cls(nla_type, val)

    def __bytes__(self):
        return self._to_bytes(struct.pack("@B", self.u8))
185
186
class NlAttrIp(NlAttr):
    """Attribute carrying an IPv4 (4-byte) or IPv6 (16-byte) address."""

    def __init__(self, nla_type, addr: str):
        """:param addr: textual address; any ':' in it selects AF_INET6."""
        super().__init__(nla_type, b"")
        self.addr = addr
        if ":" in self.addr:
            self.family = socket.AF_INET6
        else:
            self.family = socket.AF_INET

    @staticmethod
    def _validate(data):
        """Check the header and that the payload is 4 or 16 bytes long.

        The generic header checks run first, so a truncated buffer raises
        ``ValueError`` instead of ``struct.error``.

        :raises ValueError: on a bad header or an unexpected address length.
        """
        NlAttr._validate(data)
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        data_len = nla_len - 4
        if data_len != 4 and data_len != 16:
            raise ValueError(
                "Error validating attr {}: nla_len is not valid".format(nla_type)
            )

    @property
    def nla_len(self):
        """Unpadded length: header plus 16 (IPv6) or 4 (IPv4) address bytes."""
        if self.family == socket.AF_INET6:
            return 20
        return 8

    @classmethod
    def _parse(cls, data):
        """Decode the address, choosing the family from the header's nla_len."""
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        payload = data[4:nla_len]
        family = socket.AF_INET if len(payload) == 4 else socket.AF_INET6
        return cls(nla_type, socket.inet_ntop(family, payload))

    def __bytes__(self):
        return self._to_bytes(socket.inet_pton(self.family, self.addr))

    def _print_attr_value(self):
        return " addr={}".format(self.addr)
230
231
class NlAttrStr(NlAttr):
    """Attribute carrying a NUL-terminated UTF-8 string."""

    def __init__(self, nla_type, text):
        """:param text: string value, held without the NUL terminator."""
        super().__init__(nla_type, b"")
        self.text = text

    @staticmethod
    def _validate(data):
        """Check the header and that the payload decodes as UTF-8.

        Only the ``nla_len`` bytes announced by the header are checked.

        :raises ValueError: on a bad header or a non-UTF-8 payload.
        """
        NlAttr._validate(data)
        nla_len, _ = struct.unpack("@HH", data[:4])
        try:
            data[4:nla_len].decode("utf-8")
        except UnicodeDecodeError as e:
            raise ValueError("wrong utf-8 string: {}".format(e)) from e

    @property
    def nla_len(self):
        """Unpadded length: header + UTF-8 encoded text + NUL terminator.

        Counts encoded bytes rather than characters, so the value agrees
        with the header written by ``__bytes__`` for non-ASCII text.
        """
        return len(self.text.encode("utf-8")) + 5

    @classmethod
    def _parse(cls, data):
        """Decode the payload (limited to ``nla_len``) without its NUL byte."""
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        payload = data[4:nla_len]
        # Strip the terminator only if present, so a sender that omitted it
        # does not lose the string's last character.
        if payload.endswith(b"\x00"):
            payload = payload[:-1]
        return cls(nla_type, payload.decode("utf-8"))

    def __bytes__(self):
        return self._to_bytes(bytes(self.text, encoding="utf-8") + bytes(1))

    def _print_attr_value(self):
        return ' val="{}"'.format(self.text)
260
261
class NlAttrStrn(NlAttr):
    """Attribute carrying a UTF-8 string with no NUL terminator."""

    def __init__(self, nla_type, text):
        """:param text: string value."""
        super().__init__(nla_type, b"")
        self.text = text

    @staticmethod
    def _validate(data):
        """Check the header and that the payload decodes as UTF-8.

        Only the ``nla_len`` bytes announced by the header are checked.

        :raises ValueError: on a bad header or a non-UTF-8 payload.
        """
        NlAttr._validate(data)
        nla_len, _ = struct.unpack("@HH", data[:4])
        try:
            data[4:nla_len].decode("utf-8")
        except UnicodeDecodeError as e:
            raise ValueError("wrong utf-8 string: {}".format(e)) from e

    @property
    def nla_len(self):
        """Unpadded length: header + UTF-8 encoded text.

        Counts encoded bytes rather than characters, so the value agrees
        with the header written by ``__bytes__`` for non-ASCII text.
        """
        return len(self.text.encode("utf-8")) + 4

    @classmethod
    def _parse(cls, data):
        """Decode the payload, limited to the header's ``nla_len``."""
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        return cls(nla_type, data[4:nla_len].decode("utf-8"))

    def __bytes__(self):
        return self._to_bytes(bytes(self.text, encoding="utf-8"))

    def _print_attr_value(self):
        return ' val="{}"'.format(self.text)
290