xref: /freebsd/tests/atf_python/sys/netlink/attrs.py (revision 179bffddf530d20b13cbc0056c908458ddfb2ae7)
1import socket
2import struct
3from enum import Enum
4
5from atf_python.sys.netlink.utils import align4
6from atf_python.sys.netlink.utils import enum_or_int
7
8
class NlAttr(object):
    """Generic netlink attribute holding an opaque payload.

    Serialized layout: a 4-byte header made of a 16-bit total length and a
    16-bit type (both native byte order, ``@HH``), followed by the payload,
    zero-padded via ``align4`` (presumably to a 4-byte boundary). The length
    field always records the unpadded size. Subclasses override
    ``_validate``, ``_parse``, ``__bytes__`` and ``_print_attr_value`` to
    handle typed payloads.
    """

    # The low 14 bits of the 16-bit type field hold the attribute type; the
    # two top bits are the NLA_F_NESTED (0x8000) and NLA_F_NET_BYTEORDER
    # (0x4000) flags, which must not leak into the reported type.
    NLA_TYPE_MASK = 0x3FFF

    def __init__(self, nla_type, data):
        # nla_type is either an Enum member (kept for pretty-printing) or a
        # raw integer type.
        if isinstance(nla_type, Enum):
            self._nla_type = nla_type.value
            self._enum = nla_type
        else:
            self._nla_type = nla_type
            self._enum = None
        self.nla_list = []
        self._data = data

    @property
    def nla_type(self):
        """Attribute type with the nested/byte-order flag bits masked off."""
        return self._nla_type & self.NLA_TYPE_MASK

    @property
    def nla_len(self):
        """Unpadded attribute length: 4-byte header plus raw payload."""
        return len(self._data) + 4

    def add_nla(self, nla):
        """Append a child attribute to ``nla_list``."""
        self.nla_list.append(nla)

    def print_attr(self, prepend=""):
        """Print a one-line description of the attribute, prefixed by *prepend*."""
        if self._enum is not None:
            type_str = self._enum.name
        else:
            type_str = "nla#{}".format(self.nla_type)
        print(
            "{}len={} type={}({}){}".format(
                prepend, self.nla_len, type_str, self.nla_type, self._print_attr_value()
            )
        )

    @staticmethod
    def _validate(data):
        """Check that *data* starts with a well-formed attribute header.

        Bytes beyond the encoded length (e.g. alignment padding) are allowed.

        Raises:
            ValueError: if *data* is shorter than the header, or the encoded
                length is smaller than the header or larger than ``len(data)``.
        """
        if len(data) < 4:
            raise ValueError("attribute too short")
        nla_len, _nla_type = struct.unpack("@HH", data[:4])
        if nla_len > len(data):
            raise ValueError("attribute length too big")
        if nla_len < 4:
            raise ValueError("attribute length too short")

    @classmethod
    def _parse(cls, data):
        """Build an instance whose payload is the ``nla_len - 4`` bytes after the header."""
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        # Honour the encoded length so trailing padding (as produced by
        # __bytes__) is not mistaken for payload.
        return cls(nla_type, data[4:nla_len])

    @classmethod
    def from_bytes(cls, data, attr_type_enum=None):
        """Validate and parse a single serialized attribute.

        Args:
            data: serialized attribute, header first; may carry trailing padding.
            attr_type_enum: optional Enum member stored for ``print_attr``.

        Raises:
            ValueError: if ``_validate`` rejects *data*.
        """
        cls._validate(data)
        attr = cls._parse(data)
        attr._enum = attr_type_enum
        return attr

    def _to_bytes(self, data: bytes):
        """Serialize header + *data*, appending zero padding up to ``align4``.

        The header length field holds the unpadded size (``len(data) + 4``),
        and the type field keeps any flag bits present in ``_nla_type``.
        """
        padding = bytes(align4(len(data)) - len(data))
        return struct.pack("@HH", len(data) + 4, self._nla_type) + data + padding

    def __bytes__(self):
        return self._to_bytes(self._data)

    def _print_attr_value(self):
        """Hex dump of the raw payload, e.g. `` x0A x0B``."""
        return " " + " ".join(["x{:02X}".format(b) for b in self._data])
75
76
class NlAttrNested(NlAttr):
    """Attribute whose payload is a list of child attributes.

    Children are kept in ``nla_list`` and serialized back to back to form
    the payload.
    """

    def __init__(self, nla_type, val):
        super().__init__(nla_type, b"")
        # The children replace the empty list created by the base constructor.
        self.nla_list = val

    def _children_bytes(self):
        """Concatenate the serialized form of every child, in order."""
        return b"".join(bytes(child) for child in self.nla_list)

    @property
    def nla_len(self):
        """Header size plus the align4-rounded size of the serialized children."""
        return 4 + align4(len(self._children_bytes()))

    def print_attr(self, prepend=""):
        """Print this attribute as a brace-delimited block, children indented by two spaces."""
        type_str = (
            self._enum.name
            if self._enum is not None
            else "nla#{}".format(self.nla_type)
        )
        header = "{}len={} type={}({}) {{".format(
            prepend, self.nla_len, type_str, self.nla_type
        )
        print(header)
        child_prefix = prepend + "  "
        for child in self.nla_list:
            child.print_attr(child_prefix)
        print("{}}}".format(prepend))

    def __bytes__(self):
        return self._to_bytes(self._children_bytes())
102
103
class NlAttrU32(NlAttr):
    """Attribute carrying one native-endian unsigned 32-bit integer (``@I``)."""

    def __init__(self, nla_type, val):
        # enum_or_int presumably unwraps Enum members to plain ints.
        self.u32 = enum_or_int(val)
        super().__init__(nla_type, b"")

    @property
    def nla_len(self):
        # 4-byte header + 4-byte value.
        return 8

    def _print_attr_value(self):
        return " val={}".format(self.u32)

    @staticmethod
    def _validate(data):
        """Check that *data* holds a header with ``nla_len == 8`` plus the value.

        Explicit exceptions are used instead of ``assert`` so validation is
        not stripped under ``python -O``. Trailing bytes after the attribute
        are tolerated.

        Raises:
            ValueError: if *data* is shorter than 8 bytes or ``nla_len != 8``.
        """
        if len(data) < 8:
            raise ValueError("attribute too short")
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        if nla_len != 8:
            raise ValueError(
                "Error validating attr {}: nla_len is not valid".format(nla_type)
            )

    @classmethod
    def _parse(cls, data):
        # unpack_from reads only the first 8 bytes, ignoring any trailing bytes.
        _nla_len, nla_type, val = struct.unpack_from("@HHI", data)
        return cls(nla_type, val)

    def __bytes__(self):
        return self._to_bytes(struct.pack("@I", self.u32))
129
130
class NlAttrU16(NlAttr):
    """Attribute carrying one native-endian unsigned 16-bit integer (``@H``).

    The header length is 6. ``__bytes__`` may append alignment padding after
    the value, so ``_validate``/``_parse`` accept input longer than 6 bytes.
    Without this, ``from_bytes(bytes(attr))`` would be rejected.
    """

    def __init__(self, nla_type, val):
        # enum_or_int presumably unwraps Enum members to plain ints.
        self.u16 = enum_or_int(val)
        super().__init__(nla_type, b"")

    @property
    def nla_len(self):
        # 4-byte header + 2-byte value (unpadded).
        return 6

    def _print_attr_value(self):
        return " val={}".format(self.u16)

    @staticmethod
    def _validate(data):
        """Check that *data* holds a header with ``nla_len == 6`` plus the value.

        Explicit exceptions are used instead of ``assert`` so validation is
        not stripped under ``python -O``.

        Raises:
            ValueError: if *data* is shorter than 6 bytes or ``nla_len != 6``.
        """
        if len(data) < 6:
            raise ValueError("attribute too short")
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        if nla_len != 6:
            raise ValueError(
                "Error validating attr {}: nla_len is not valid".format(nla_type)
            )

    @classmethod
    def _parse(cls, data):
        # unpack_from reads only the first 6 bytes, skipping any padding.
        _nla_len, nla_type, val = struct.unpack_from("@HHH", data)
        return cls(nla_type, val)

    def __bytes__(self):
        return self._to_bytes(struct.pack("@H", self.u16))
156
157
class NlAttrU8(NlAttr):
    """Attribute carrying one unsigned 8-bit integer (``@B``).

    The header length is 5. ``__bytes__`` may append alignment padding after
    the value, so ``_validate``/``_parse`` accept input longer than 5 bytes.
    Without this, ``from_bytes(bytes(attr))`` would be rejected.
    """

    def __init__(self, nla_type, val):
        # enum_or_int presumably unwraps Enum members to plain ints.
        self.u8 = enum_or_int(val)
        super().__init__(nla_type, b"")

    @property
    def nla_len(self):
        # 4-byte header + 1-byte value (unpadded).
        return 5

    def _print_attr_value(self):
        return " val={}".format(self.u8)

    @staticmethod
    def _validate(data):
        """Check that *data* holds a header with ``nla_len == 5`` plus the value.

        Explicit exceptions are used instead of ``assert`` so validation is
        not stripped under ``python -O``.

        Raises:
            ValueError: if *data* is shorter than 5 bytes or ``nla_len != 5``.
        """
        if len(data) < 5:
            raise ValueError("attribute too short")
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        if nla_len != 5:
            raise ValueError(
                "Error validating attr {}: nla_len is not valid".format(nla_type)
            )

    @classmethod
    def _parse(cls, data):
        # unpack_from reads only the first 5 bytes, skipping any padding.
        _nla_len, nla_type, val = struct.unpack_from("@HHB", data)
        return cls(nla_type, val)

    def __bytes__(self):
        return self._to_bytes(struct.pack("@B", self.u8))
183
184
class NlAttrIp(NlAttr):
    """Attribute carrying an IPv4 (4-byte) or IPv6 (16-byte) address.

    ``addr`` holds the textual form; the family is inferred from it.
    """

    def __init__(self, nla_type, addr: str):
        super().__init__(nla_type, b"")
        self.addr = addr
        # Only IPv6 textual addresses contain a colon.
        if ":" in self.addr:
            self.family = socket.AF_INET6
        else:
            self.family = socket.AF_INET

    @staticmethod
    def _validate(data):
        """Check the header and that the payload is 4 or 16 bytes long.

        Raises:
            ValueError: if the header is malformed or inconsistent with
                ``len(data)``, or the payload size matches neither family.
        """
        # Generic header checks first, so short or truncated input is reported
        # as ValueError instead of struct.error or a bogus parse.
        NlAttr._validate(data)
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        data_len = nla_len - 4
        if data_len != 4 and data_len != 16:
            raise ValueError(
                "Error validating attr {}: nla_len is not valid".format(  # noqa: E501
                    nla_type
                )
            )

    @property
    def nla_len(self):
        """Header plus 16 bytes (IPv6) or 4 bytes (IPv4)."""
        if self.family == socket.AF_INET6:
            return 20
        return 8

    @classmethod
    def _parse(cls, data):
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        # Decide the family from the encoded length (as _validate does), not
        # from len(data), which may include bytes past this attribute.
        if nla_len - 4 == 4:
            addr = socket.inet_ntop(socket.AF_INET, data[4:8])
        else:
            addr = socket.inet_ntop(socket.AF_INET6, data[4:20])
        return cls(nla_type, addr)

    def __bytes__(self):
        return self._to_bytes(socket.inet_pton(self.family, self.addr))

    def _print_attr_value(self):
        return " addr={}".format(self.addr)
228
229
class NlAttrStr(NlAttr):
    """Attribute carrying a NUL-terminated UTF-8 string."""

    def __init__(self, nla_type, text):
        super().__init__(nla_type, b"")
        self.text = text

    @staticmethod
    def _validate(data):
        """Check the header and that the payload (within ``nla_len``) is valid UTF-8.

        Raises:
            ValueError: on a malformed header or undecodable payload.
        """
        NlAttr._validate(data)
        nla_len, _nla_type = struct.unpack("@HH", data[:4])
        try:
            data[4:nla_len].decode("utf-8")
        except UnicodeDecodeError as e:
            raise ValueError("wrong utf-8 string: {}".format(e)) from e

    @property
    def nla_len(self):
        """Header + UTF-8 byte length of the text + NUL terminator (unpadded).

        Counts encoded bytes, not characters, so it matches ``__bytes__`` for
        non-ASCII text.
        """
        return len(bytes(self.text, encoding="utf-8")) + 5

    @classmethod
    def _parse(cls, data):
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        # Limit to the encoded length so alignment padding is excluded, then
        # drop the terminating NUL if there is one.
        raw = data[4:nla_len]
        if raw.endswith(b"\x00"):
            raw = raw[:-1]
        return cls(nla_type, raw.decode("utf-8"))

    def __bytes__(self):
        return self._to_bytes(bytes(self.text, encoding="utf-8") + bytes(1))

    def _print_attr_value(self):
        return ' val="{}"'.format(self.text)
258
259
class NlAttrStrn(NlAttr):
    """Attribute carrying a UTF-8 string without a NUL terminator."""

    def __init__(self, nla_type, text):
        super().__init__(nla_type, b"")
        self.text = text

    @staticmethod
    def _validate(data):
        """Check the header and that the payload (within ``nla_len``) is valid UTF-8.

        Raises:
            ValueError: on a malformed header or undecodable payload.
        """
        NlAttr._validate(data)
        nla_len, _nla_type = struct.unpack("@HH", data[:4])
        try:
            data[4:nla_len].decode("utf-8")
        except UnicodeDecodeError as e:
            raise ValueError("wrong utf-8 string: {}".format(e)) from e

    @property
    def nla_len(self):
        """Header + UTF-8 byte length of the text (unpadded).

        Counts encoded bytes, not characters, so it matches ``__bytes__`` for
        non-ASCII text.
        """
        return len(bytes(self.text, encoding="utf-8")) + 4

    @classmethod
    def _parse(cls, data):
        nla_len, nla_type = struct.unpack("@HH", data[:4])
        # Limit to the encoded length so alignment padding is not decoded
        # into the text.
        text = data[4:nla_len].decode("utf-8")
        return cls(nla_type, text)

    def __bytes__(self):
        return self._to_bytes(bytes(self.text, encoding="utf-8"))

    def _print_attr_value(self):
        return ' val="{}"'.format(self.text)
288