# xref: /freebsd/tests/atf_python/sys/netlink/attrs.py (revision a03411e84728e9b267056fd31c7d1d9d1dc1b01e)
1import socket
2import struct
3from enum import Enum
4
5from atf_python.sys.netlink.utils import align4
6from atf_python.sys.netlink.utils import enum_or_int
7
8
class NlAttr(object):
    """Generic netlink attribute (``struct nlattr``) with an opaque payload.

    Wire layout: a 4-byte header made of two native-endian u16 values,
    ``nla_len`` and ``nla_type``, followed by the payload, which is
    zero-padded up to ``align4(len(payload))``.  ``nla_len`` counts the
    header plus the *unpadded* payload.  Subclasses override
    ``_validate``/``_parse``/``__bytes__``/``nla_len`` for typed payloads.
    """

    HDR_LEN = 4  # sizeof(struct nlattr)

    def __init__(self, nla_type, data):
        """Create an attribute.

        :param nla_type: attribute type; either an ``Enum`` member (its
            ``value`` is used and the member is kept for printing) or an int.
        :param data: raw payload bytes, without header or padding.
        """
        if isinstance(nla_type, Enum):
            self._nla_type = nla_type.value
            self._enum = nla_type
        else:
            self._nla_type = nla_type
            self._enum = None
        self.nla_list = []
        self._data = data

    @property
    def nla_type(self):
        """Attribute type masked with 0x3FFF (top two bits dropped)."""
        return self._nla_type & 0x3FFF

    @property
    def nla_len(self):
        """Header length plus unpadded payload length."""
        return len(self._data) + self.HDR_LEN

    def add_nla(self, nla):
        """Append ``nla`` to this attribute's child list."""
        self.nla_list.append(nla)

    def print_attr(self, prepend=""):
        """Print one line: ``prepend``, length, type name and value.

        The type name is the Enum member name when known, otherwise
        ``nla#<type>``; the value part comes from ``_print_attr_value``.
        """
        if self._enum is not None:
            type_str = self._enum.name
        else:
            type_str = "nla#{}".format(self.nla_type)
        print(
            "{}len={} type={}({}){}".format(
                prepend, self.nla_len, type_str, self.nla_type, self._print_attr_value()
            )
        )

    @staticmethod
    def _validate(data):
        """Check that ``data`` starts with a well-formed attribute header.

        ``data`` may be longer than ``nla_len`` (e.g. include alignment
        padding); only the ``nla_len`` prefix is used by ``_parse``.

        :raises ValueError: if ``data`` is shorter than the header, or the
            header's ``nla_len`` is below the header size or beyond ``data``.
        """
        if len(data) < 4:
            raise ValueError("attribute too short")
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        if nla_len > len(data):
            raise ValueError("attribute length too big")
        if nla_len < 4:
            raise ValueError("attribute length too short")

    @classmethod
    def _parse(cls, data):
        """Build an attribute from already-validated ``data``.

        Only the bytes covered by the header's ``nla_len`` become the
        payload, so trailing padding is not absorbed into ``_data``.
        """
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        return cls(nla_type, data[4:nla_len])

    @classmethod
    def from_bytes(cls, data, attr_type_enum=None):
        """Validate and parse ``data`` into an instance of ``cls``.

        :param attr_type_enum: optional Enum member stored for printing;
            it replaces whatever ``_enum`` the parsed instance had.
        :raises ValueError: if validation fails.
        """
        cls._validate(data)
        attr = cls._parse(data)
        attr._enum = attr_type_enum
        return attr

    def _to_bytes(self, data: bytes):
        """Serialize a header followed by ``data`` and zero padding.

        The header's length field covers the header and the unpadded
        ``data``; the padding is sized to reach ``align4(len(data))``.
        """
        padding = bytes(align4(len(data)) - len(data))
        return struct.pack("@HH", len(data) + 4, self._nla_type) + data + padding

    def __bytes__(self):
        return self._to_bytes(self._data)

    def _print_attr_value(self):
        """Render the raw payload as space-separated ``xHH`` hex bytes."""
        return " " + " ".join(["x{:02X}".format(b) for b in self._data])
77
78
class NlAttrNested(NlAttr):
    """Container attribute whose payload is a list of child attributes.

    Children live in ``self.nla_list``; serialization concatenates
    ``bytes(child)`` for every child and wraps the result in a header.
    """

    def __init__(self, nla_type, val):
        """:param val: list of child attributes, stored as-is (not copied)."""
        super().__init__(nla_type, b"")
        self.nla_list = val

    def get_nla(self, nla_type):
        """Return the first child whose ``nla_type`` matches, or ``None``.

        ``nla_type`` may be an int or an Enum member; it is normalized with
        ``enum_or_int`` before comparison.
        """
        wanted = enum_or_int(nla_type)
        return next(
            (child for child in self.nla_list if child.nla_type == wanted), None
        )

    @property
    def nla_len(self):
        """Header size plus ``align4`` of the total encoded children size."""
        payload_size = sum(len(bytes(child)) for child in self.nla_list)
        return self.HDR_LEN + align4(payload_size)

    def print_attr(self, prepend=""):
        """Print a header line, each child indented two more spaces, then ``}``."""
        type_str = (
            self._enum.name
            if self._enum is not None
            else "nla#{}".format(self.nla_type)
        )
        header = "{}len={} type={}({}) {{".format(
            prepend, self.nla_len, type_str, self.nla_type
        )
        print(header)
        child_prefix = prepend + "  "
        for child in self.nla_list:
            child.print_attr(child_prefix)
        print("{}}}".format(prepend))

    def __bytes__(self):
        payload = b"".join(bytes(child) for child in self.nla_list)
        return self._to_bytes(payload)
111
112
class NlAttrU32(NlAttr):
    """Netlink attribute carrying a native-endian unsigned 32-bit value."""

    def __init__(self, nla_type, val):
        """:param val: int or Enum member, normalized with ``enum_or_int``."""
        self.u32 = enum_or_int(val)
        super().__init__(nla_type, b"")

    @property
    def nla_len(self):
        # 4-byte header + 4-byte value.
        return 8

    def _print_attr_value(self):
        return " val={}".format(self.u32)

    @staticmethod
    def _validate(data):
        """Check that ``data`` is exactly one 8-byte u32 attribute.

        Raises ``ValueError`` instead of using ``assert`` so the check is
        not stripped under ``python -O`` and uses the same error type as
        ``NlAttr._validate``.
        """
        if len(data) != 8:
            raise ValueError(
                "u32 attribute: expected 8 bytes, got {}".format(len(data))
            )
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        if nla_len != 8:
            raise ValueError(
                "u32 attribute {}: nla_len is {}, expected 8".format(nla_type, nla_len)
            )

    @classmethod
    def _parse(cls, data):
        # ``data`` is exactly 8 bytes here (enforced by _validate).
        nla_len, nla_type, val = struct.unpack("@HHI", data)
        return cls(nla_type, val)

    def __bytes__(self):
        return self._to_bytes(struct.pack("@I", self.u32))
138
139
class NlAttrS32(NlAttr):
    """Netlink attribute carrying a native-endian signed 32-bit value."""

    def __init__(self, nla_type, val):
        """:param val: int or Enum member, normalized with ``enum_or_int``."""
        self.s32 = enum_or_int(val)
        super().__init__(nla_type, b"")

    @property
    def nla_len(self):
        # 4-byte header + 4-byte value.
        return 8

    def _print_attr_value(self):
        return " val={}".format(self.s32)

    @staticmethod
    def _validate(data):
        """Check that ``data`` is exactly one 8-byte s32 attribute.

        Raises ``ValueError`` instead of using ``assert`` so the check is
        not stripped under ``python -O`` and uses the same error type as
        ``NlAttr._validate``.
        """
        if len(data) != 8:
            raise ValueError(
                "s32 attribute: expected 8 bytes, got {}".format(len(data))
            )
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        if nla_len != 8:
            raise ValueError(
                "s32 attribute {}: nla_len is {}, expected 8".format(nla_type, nla_len)
            )

    @classmethod
    def _parse(cls, data):
        # ``data`` is exactly 8 bytes here (enforced by _validate).
        nla_len, nla_type, val = struct.unpack("@HHi", data)
        return cls(nla_type, val)

    def __bytes__(self):
        return self._to_bytes(struct.pack("@i", self.s32))
165
166
class NlAttrU16(NlAttr):
    """Netlink attribute carrying a native-endian unsigned 16-bit value."""

    def __init__(self, nla_type, val):
        """:param val: int or Enum member, normalized with ``enum_or_int``."""
        self.u16 = enum_or_int(val)
        super().__init__(nla_type, b"")

    @property
    def nla_len(self):
        # 4-byte header + 2-byte value; alignment padding is not counted.
        return 6

    def _print_attr_value(self):
        return " val={}".format(self.u16)

    @staticmethod
    def _validate(data):
        """Check that ``data`` is exactly one 6-byte u16 attribute.

        Raises ``ValueError`` instead of using ``assert`` so the check is
        not stripped under ``python -O`` and uses the same error type as
        ``NlAttr._validate``.
        """
        if len(data) != 6:
            raise ValueError(
                "u16 attribute: expected 6 bytes, got {}".format(len(data))
            )
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        if nla_len != 6:
            raise ValueError(
                "u16 attribute {}: nla_len is {}, expected 6".format(nla_type, nla_len)
            )

    @classmethod
    def _parse(cls, data):
        # ``data`` is exactly 6 bytes here (enforced by _validate).
        nla_len, nla_type, val = struct.unpack("@HHH", data)
        return cls(nla_type, val)

    def __bytes__(self):
        return self._to_bytes(struct.pack("@H", self.u16))
192
193
class NlAttrU8(NlAttr):
    """Netlink attribute carrying an unsigned 8-bit value."""

    def __init__(self, nla_type, val):
        """:param val: int or Enum member, normalized with ``enum_or_int``."""
        self.u8 = enum_or_int(val)
        super().__init__(nla_type, b"")

    @property
    def nla_len(self):
        # 4-byte header + 1-byte value; alignment padding is not counted.
        return 5

    def _print_attr_value(self):
        return " val={}".format(self.u8)

    @staticmethod
    def _validate(data):
        """Check that ``data`` is exactly one 5-byte u8 attribute.

        Raises ``ValueError`` instead of using ``assert`` so the check is
        not stripped under ``python -O`` and uses the same error type as
        ``NlAttr._validate``.
        """
        if len(data) != 5:
            raise ValueError(
                "u8 attribute: expected 5 bytes, got {}".format(len(data))
            )
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        if nla_len != 5:
            raise ValueError(
                "u8 attribute {}: nla_len is {}, expected 5".format(nla_type, nla_len)
            )

    @classmethod
    def _parse(cls, data):
        # ``data`` is exactly 5 bytes here (enforced by _validate).
        nla_len, nla_type, val = struct.unpack("@HHB", data)
        return cls(nla_type, val)

    def __bytes__(self):
        return self._to_bytes(struct.pack("@B", self.u8))
219
220
class NlAttrIp(NlAttr):
    """Netlink attribute carrying an IPv4 (4-byte) or IPv6 (16-byte) address.

    The address is kept as text in ``self.addr``; ``self.family`` is
    ``AF_INET6`` when the text contains ``:`` and ``AF_INET`` otherwise.
    """

    def __init__(self, nla_type, addr: str):
        super().__init__(nla_type, b"")
        self.addr = addr
        if ":" in self.addr:
            self.family = socket.AF_INET6
        else:
            self.family = socket.AF_INET

    @staticmethod
    def _validate(data):
        """Check the header and that the payload is 4 or 16 bytes long.

        :raises ValueError: on a malformed header (via ``NlAttr._validate``)
            or an address length other than 4 or 16 bytes.
        """
        NlAttr._validate(data)
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        data_len = nla_len - 4
        if data_len != 4 and data_len != 16:
            raise ValueError(
                "Error validating attr {}: nla_len is not valid".format(  # noqa: E501
                    nla_type
                )
            )

    @property
    def nla_len(self):
        """Header plus address size: 20 for IPv6, 8 for IPv4."""
        if self.family == socket.AF_INET6:
            return 20
        return 8

    @classmethod
    def _parse(cls, data):
        """Decode the address using the payload length from ``nla_len``.

        Using ``nla_len`` (as ``_validate`` does) rather than ``len(data)``
        keeps any trailing padding from changing the detected family.
        """
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        payload = data[4:nla_len]
        family = socket.AF_INET if len(payload) == 4 else socket.AF_INET6
        addr = socket.inet_ntop(family, payload)
        return cls(nla_type, addr)

    def __bytes__(self):
        return self._to_bytes(socket.inet_pton(self.family, self.addr))

    def _print_attr_value(self):
        return " addr={}".format(self.addr)
264
265
class NlAttrIp4(NlAttrIp):
    """IP address attribute that only accepts IPv4 addresses."""

    def __init__(self, nla_type, addr: str):
        """:raises ValueError: if ``addr`` is not detected as IPv4.

        An explicit raise replaces the former ``assert`` so the check still
        runs under ``python -O``.
        """
        super().__init__(nla_type, addr)
        if self.family != socket.AF_INET:
            raise ValueError("not an IPv4 address: {}".format(addr))
270
271
class NlAttrIp6(NlAttrIp):
    """IP address attribute that only accepts IPv6 addresses."""

    def __init__(self, nla_type, addr: str):
        """:raises ValueError: if ``addr`` is not detected as IPv6.

        An explicit raise replaces the former ``assert`` so the check still
        runs under ``python -O``.
        """
        super().__init__(nla_type, addr)
        if self.family != socket.AF_INET6:
            raise ValueError("not an IPv6 address: {}".format(addr))
276
277
class NlAttrStr(NlAttr):
    """Netlink attribute carrying a NUL-terminated UTF-8 string.

    Serialized payload: UTF-8 encoding of ``text`` followed by one NUL byte.
    """

    def __init__(self, nla_type, text):
        super().__init__(nla_type, b"")
        self.text = text

    @staticmethod
    def _validate(data):
        """Check the header and that the string portion is valid UTF-8.

        The decoded portion matches what ``_parse`` keeps: the bytes within
        ``nla_len``, up to the first NUL.

        :raises ValueError: on a malformed header or invalid UTF-8.
        """
        NlAttr._validate(data)
        nla_len, _ = struct.unpack("@HH", data[:4])
        try:
            data[4:nla_len].split(b"\x00", 1)[0].decode("utf-8")
        except UnicodeDecodeError as e:
            raise ValueError("wrong utf-8 string: {}".format(e)) from e

    @property
    def nla_len(self):
        """Header + UTF-8 byte length of ``text`` + terminating NUL.

        Byte length, not character count, so non-ASCII text matches the
        header written by ``__bytes__``.
        """
        return len(self.text.encode("utf-8")) + 5

    @classmethod
    def _parse(cls, data):
        """Decode the string from the payload covered by ``nla_len``.

        The string ends at the first NUL (C-string semantics); if no NUL is
        present the whole payload is kept instead of losing its last byte.
        """
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        raw = data[4:nla_len].split(b"\x00", 1)[0]
        return cls(nla_type, raw.decode("utf-8"))

    def __bytes__(self):
        return self._to_bytes(bytes(self.text, encoding="utf-8") + bytes(1))

    def _print_attr_value(self):
        return ' val="{}"'.format(self.text)
306
307
class NlAttrStrn(NlAttr):
    """Netlink attribute carrying a UTF-8 string with no NUL terminator.

    Serialized payload: exactly the UTF-8 encoding of ``text``.
    """

    def __init__(self, nla_type, text):
        super().__init__(nla_type, b"")
        self.text = text

    @staticmethod
    def _validate(data):
        """Check the header and that the payload within ``nla_len`` is UTF-8.

        :raises ValueError: on a malformed header or invalid UTF-8.
        """
        NlAttr._validate(data)
        nla_len, _ = struct.unpack("@HH", data[:4])
        try:
            data[4:nla_len].decode("utf-8")
        except UnicodeDecodeError as e:
            raise ValueError("wrong utf-8 string: {}".format(e)) from e

    @property
    def nla_len(self):
        """Header + UTF-8 byte length of ``text``.

        Byte length, not character count, so non-ASCII text matches the
        header written by ``__bytes__``.
        """
        return len(self.text.encode("utf-8")) + 4

    @classmethod
    def _parse(cls, data):
        """Decode the payload covered by ``nla_len`` (padding excluded)."""
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        return cls(nla_type, data[4:nla_len].decode("utf-8"))

    def __bytes__(self):
        return self._to_bytes(bytes(self.text, encoding="utf-8"))

    def _print_attr_value(self):
        return ' val="{}"'.format(self.text)
336