xref: /freebsd/tests/atf_python/sys/netpfil/ipfw/ioctl.py (revision 32cd3ee5901ea33d41ff550e5f40ce743c8d4165)
1#!/usr/bin/env python3
2import os
3import socket
4import struct
5import subprocess
6import sys
7from ctypes import c_byte
8from ctypes import c_char
9from ctypes import c_int
10from ctypes import c_long
11from ctypes import c_uint32
12from ctypes import c_uint8
13from ctypes import c_ulong
14from ctypes import c_ushort
15from ctypes import sizeof
16from ctypes import Structure
17from enum import Enum
18from typing import Any
19from typing import Dict
20from typing import List
21from typing import NamedTuple
22from typing import Optional
23from typing import Union
24
25import pytest
26from atf_python.sys.netpfil.ipfw.insns import BaseInsn
27from atf_python.sys.netpfil.ipfw.insns import insn_attrs
28from atf_python.sys.netpfil.ipfw.ioctl_headers import IpFwTableLookupType
29from atf_python.sys.netpfil.ipfw.ioctl_headers import IpFwTableValueType
30from atf_python.sys.netpfil.ipfw.ioctl_headers import IpFwTlvType
31from atf_python.sys.netpfil.ipfw.ioctl_headers import Op3CmdType
32from atf_python.sys.netpfil.ipfw.utils import align8
33from atf_python.sys.netpfil.ipfw.utils import AttrDescr
34from atf_python.sys.netpfil.ipfw.utils import enum_from_int
35from atf_python.sys.netpfil.ipfw.utils import prepare_attrs_map
36
37
class IpFw3OpHeader(Structure):
    """ctypes mirror of the header placed in front of every op3 message.

    Four unsigned shorts in native byte order: the opcode, its version and
    two reserved words.
    """

    _fields_ = (
        ("opcode", c_ushort),  # operation code
        ("version", c_ushort),  # opcode version
        ("reserved1", c_ushort),
        ("reserved2", c_ushort),
    )
45
46
class IpFwObjTlv(Structure):
    """ctypes mirror of the generic TLV header that starts every ipfw object.

    ``length`` is the size in bytes of the whole TLV, header included.
    """

    _fields_ = (
        ("n_type", c_ushort),  # TLV type
        ("flags", c_ushort),
        ("length", c_uint32),  # total TLV size in bytes
    )
53
54
class BaseTlv(object):
    """Base class for ipfw objects that start with an ``IpFwObjTlv`` header.

    Subclasses implement ``__bytes__`` and, where their payload needs it,
    override ``_validate`` / ``_parse``, which ``from_bytes`` drives.
    """

    # Enum used to turn numeric TLV types into readable names.
    obj_enum_class = IpFwTlvType

    def __init__(self, obj_type):
        """Create a TLV of *obj_type*, given as an Enum member or a raw int."""
        if isinstance(obj_type, Enum):
            self.obj_type = obj_type.value
            self._enum = obj_type
        else:
            self.obj_type = obj_type
            # obj_name falls back to "tlv#N" when this lookup yields None.
            self._enum = enum_from_int(self.obj_enum_class, obj_type)
        # Child objects; printed recursively by print_obj().
        self.obj_list = []

    def add_obj(self, obj):
        """Append *obj* to the child list."""
        self.obj_list.append(obj)

    @property
    def len(self):
        """Size in bytes of the serialised object."""
        return len(bytes(self))

    @property
    def obj_name(self):
        """Enum member name, or ``tlv#<type>`` for unrecognised types."""
        if self._enum is not None:
            return self._enum.name
        return "tlv#{}".format(self.obj_type)

    def print_hdr(self, prepend=""):
        """Print a one-line summary of this object prefixed by *prepend*."""
        print(
            "{}len={} type={}({}){}".format(
                prepend, self.len, self.obj_name, self.obj_type, self._print_obj_value()
            )
        )

    def print_obj(self, prepend=""):
        """Print this object, then its children indented two more spaces."""
        self.print_hdr(prepend)
        child_prepend = "  " + prepend
        for obj in self.obj_list:
            obj.print_obj(child_prepend)

    def print_obj_hex(self, prepend=""):
        """Print *prepend*, an empty line, then the hex dump of the object."""
        print(prepend)
        print()
        print(self.as_hexdump())

    @classmethod
    def _validate(cls, data):
        """Raise ValueError unless *data* is exactly the size its header declares."""
        hdr_size = sizeof(IpFwObjTlv)
        if len(data) < hdr_size:
            raise ValueError("TLV too short")
        hdr = IpFwObjTlv.from_buffer_copy(data[:hdr_size])
        if len(data) != hdr.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        """Build an instance carrying only the header's TLV type."""
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        return cls(hdr.n_type)

    @classmethod
    def from_bytes(cls, data, attr_map=None):
        """Validate and decode *data*; *attr_map* describes child TLV types."""
        cls._validate(data)
        return cls._parse(data, attr_map)

    def __bytes__(self):
        raise NotImplementedError()

    def _print_obj_value(self):
        # Hex dump of everything after the generic TLV header. Only some
        # subclasses keep their raw bytes in ``_data``; fall back to the
        # serialised form instead of raising AttributeError.
        data = getattr(self, "_data", None)
        if data is None:
            data = bytes(self)
        payload = data[sizeof(IpFwObjTlv) :]
        return " " + " ".join("x{:02X}".format(b) for b in payload)

    def as_hexdump(self):
        """Return the serialised object as space-separated ``xNN`` bytes."""
        return " ".join("x{:02X}".format(b) for b in bytes(self))
128
129
class UnknownTlv(BaseTlv):
    """TLV of a type with no registered decoder; its bytes are kept verbatim."""

    def __init__(self, obj_type, data):
        super().__init__(obj_type)
        # The complete TLV, including its IpFwObjTlv header.
        self._data = data

    @classmethod
    def _validate(cls, data):
        """Raise ValueError unless *data* is a complete TLV of its declared size."""
        # Only the generic header is mandatory for an unknown TLV. The former
        # check against sizeof(IpFwObjNTlv) (80 bytes) wrongly rejected every
        # unknown TLV smaller than a named-object TLV.
        if len(data) < sizeof(IpFwObjTlv):
            raise ValueError("TLV size is too short")
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        if len(data) != hdr.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        return cls(hdr.n_type, data)

    def __bytes__(self):
        # bytes() requires __bytes__ to return a real bytes object; the stored
        # data may be a bytearray or memoryview slice.
        return bytes(self._data)
151
152
class Tlv(BaseTlv):
    @staticmethod
    def parse_tlvs(data, attr_map):
        """Split *data* into consecutive TLVs and decode each one.

        *attr_map* maps a TLV type to a descriptor dict whose ``"ad"`` entry
        has a ``cls`` used for decoding and whose optional ``"child"`` entry
        is the map for nested TLVs. Types missing from the map (or a ``None``
        map) are decoded as ``UnknownTlv``. Trailing bytes shorter than a TLV
        header are ignored.

        Raises ValueError if a TLV declares a length smaller than its header
        or extends past the end of *data*.
        """
        if attr_map is None:
            attr_map = {}
        hdr_size = sizeof(IpFwObjTlv)
        off = 0
        ret = []
        while off + hdr_size <= len(data):
            hdr = IpFwObjTlv.from_buffer_copy(data[off : off + hdr_size])
            if hdr.length < hdr_size:
                # A zero length would never advance `off` and loop forever.
                raise ValueError(
                    "TLV length {} is smaller than TLV header".format(hdr.length)
                )
            if off + hdr.length > len(data):
                raise ValueError("TLV size do not match")
            obj_data = data[off : off + hdr.length]
            obj_descr = attr_map.get(hdr.n_type, None)
            if obj_descr is None:
                tlv_cls = UnknownTlv
                child_map = {}
            else:
                tlv_cls = obj_descr["ad"].cls
                child_map = obj_descr.get("child", {})
            ret.append(tlv_cls.from_bytes(obj_data, child_map))
            off += hdr.length
        return ret
178
179
class IpFwObjNTlv(Structure):
    """ctypes mirror of a named-object TLV.

    A generic ``IpFwObjTlv`` header followed by an index, set number,
    object type, a spare word and a fixed 64-byte name.
    """

    _fields_ = (
        ("head", IpFwObjTlv),
        ("idx", c_ushort),
        ("n_set", c_uint8),
        ("n_type", c_uint8),
        ("spare", c_uint32),
        ("name", c_char * 64),  # NUL-padded name
    )
189
190
class NTlv(Tlv):
    """Named-object TLV (``IpFwObjNTlv``): index, set, type and a 64-byte name."""

    def __init__(self, obj_type, idx=0, n_set=0, n_type=0, name=None):
        super().__init__(obj_type)
        self.n_idx = idx
        self.n_set = n_set
        self.n_type = n_type
        self.n_name = name

    @classmethod
    def _validate(cls, data):
        """Raise ValueError unless *data* is exactly one IpFwObjNTlv."""
        if len(data) != sizeof(IpFwObjNTlv):
            raise ValueError("TLV size is not correct")
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        if len(data) != hdr.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        hdr = IpFwObjNTlv.from_buffer_copy(data[: sizeof(IpFwObjNTlv)])
        # Reading a c_char array field yields the bytes before the first NUL.
        name = hdr.name.decode("utf-8")
        return cls(hdr.head.n_type, hdr.idx, hdr.n_set, hdr.n_type, name)

    def __bytes__(self):
        # ``name`` defaults to None; serialise that as an empty (all-NUL) name
        # instead of failing with AttributeError. The encoded name is cut at
        # 64 bytes and NUL-padded to the fixed field width.
        raw_name = (self.n_name or "").encode("utf-8")
        name_bytes = raw_name[:64].ljust(64, b"\0")
        hdr = IpFwObjNTlv(
            head=IpFwObjTlv(n_type=self.obj_type, length=sizeof(IpFwObjNTlv)),
            idx=self.n_idx,
            n_set=self.n_set,
            n_type=self.n_type,
            name=name_bytes,
        )
        return bytes(hdr)

    def _print_obj_value(self):
        return " " + "type={} set={} idx={} name={}".format(
            self.n_type, self.n_set, self.n_idx, self.n_name
        )
231
232
class IpFwObjCTlv(Structure):
    """ctypes mirror of a container TLV header.

    A generic ``IpFwObjTlv`` header followed by the number of child objects,
    the per-object size, a version byte and a flags byte.
    """

    _fields_ = (
        ("head", IpFwObjTlv),
        ("count", c_uint32),  # number of child objects
        ("objsize", c_ushort),  # size of each child object
        ("version", c_uint8),
        ("flags", c_uint8),
    )
241
242
class CTlv(Tlv):
    """Container TLV: an ``IpFwObjCTlv`` header followed by child TLVs."""

    def __init__(self, obj_type, obj_list=None):
        super().__init__(obj_type)
        if obj_list:
            self.obj_list.extend(obj_list)

    @classmethod
    def _validate(cls, data):
        """Raise ValueError unless *data* is a complete container TLV."""
        if len(data) < sizeof(IpFwObjCTlv):
            raise ValueError("TLV too short")
        hdr = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)])
        if len(data) != hdr.head.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        """Decode the children and check them against the header's count."""
        hdr = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)])
        tlv_list = cls.parse_tlvs(data[sizeof(IpFwObjCTlv) :], attr_map)
        if len(tlv_list) != hdr.count:
            raise ValueError("wrong number of objects")
        return cls(hdr.head.n_type, obj_list=tlv_list)

    def __bytes__(self):
        children = [bytes(obj) for obj in self.obj_list]
        payload = b"".join(children)
        # objsize is taken from the first child (0 for an empty container).
        objsize = len(children[0]) if children else 0
        hdr = IpFwObjCTlv(
            # The header length covers the header plus all children. It was
            # previously hard-coded to sizeof(IpFwObjNTlv), which broke
            # round-tripping through _validate.
            head=IpFwObjTlv(
                n_type=self.obj_type, length=sizeof(IpFwObjCTlv) + len(payload)
            ),
            count=len(children),
            objsize=objsize,
        )
        return bytes(hdr) + payload

    def _print_obj_value(self):
        return ""
284
285
class IpFwRule(Structure):
    """ctypes mirror of the fixed header that precedes a rule's instructions.

    ``cmd_len`` counts the 4-byte instruction words following the header and
    ``act_ofs`` is the word offset of the first action instruction.
    """

    _fields_ = (
        ("act_ofs", c_ushort),
        ("cmd_len", c_ushort),
        ("spare", c_ushort),
        ("n_set", c_uint8),
        ("flags", c_uint8),
        ("rulenum", c_uint32),
        ("n_id", c_uint32),
    )
296
297
class RawRule(Tlv):
    """One ipfw rule: an ``IpFwRule`` header followed by its instructions."""

    def __init__(self, obj_type=0, n_set=0, rulenum=0, obj_list=None):
        super().__init__(obj_type)
        self.n_set = n_set
        self.rulenum = rulenum
        if obj_list:
            self.obj_list.extend(obj_list)

    @classmethod
    def _validate(cls, data):
        """Raise ValueError unless *data* matches the header's cmd_len."""
        min_size = sizeof(IpFwRule)
        if len(data) < min_size:
            raise ValueError("rule TLV too short")
        rule = IpFwRule.from_buffer_copy(data[:min_size])
        # cmd_len counts 4-byte instruction words after the header.
        if len(data) != min_size + rule.cmd_len * 4:
            raise ValueError("rule TLV cmd_len incorrect")

    @classmethod
    def _parse(cls, data, attr_map):
        hdr = IpFwRule.from_buffer_copy(data[: sizeof(IpFwRule)])
        return cls(
            n_set=hdr.n_set,
            rulenum=hdr.rulenum,
            obj_list=BaseInsn.parse_insns(data[sizeof(IpFwRule) :], insn_attrs),
        )

    def __bytes__(self):
        act_ofs = None
        cmd_len = 0
        chunks = []
        for obj in self.obj_list:
            # act_ofs is the word offset of the *first* action. Using None for
            # "not seen yet" keeps a first action at offset 0 from being
            # overwritten by a later action (the old `act_ofs == 0` test did).
            if obj.is_action and act_ofs is None:
                act_ofs = cmd_len
            obj_bytes = bytes(obj)
            cmd_len += len(obj_bytes) // 4
            chunks.append(obj_bytes)

        hdr = IpFwRule(
            act_ofs=act_ofs if act_ofs is not None else 0,
            cmd_len=cmd_len,
            n_set=self.n_set,
            rulenum=self.rulenum,
        )
        return bytes(hdr) + b"".join(chunks)

    @property
    def obj_name(self):
        return "rule#{}".format(self.rulenum)

    def _print_obj_value(self):
        cmd_len = sum(len(bytes(obj)) for obj in self.obj_list) // 4
        return " set={} cmd_len={}".format(self.n_set, cmd_len)
351
352
class CTlvRule(CTlv):
    """Container TLV holding ``RawRule`` objects, each padded to 8 bytes.

    Validation is inherited from CTlv: the header's length must equal the
    size of the data.
    """

    def __init__(self, obj_type=IpFwTlvType.IPFW_TLV_RULE_LIST, obj_list=None):
        super().__init__(obj_type, obj_list)

    @classmethod
    def _parse(cls, data, attr_map):
        """Decode rules laid out back to back at align8() boundaries."""
        chdr = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)])
        rule_list = []
        off = sizeof(IpFwObjCTlv)
        while off + sizeof(IpFwRule) <= len(data):
            hdr = IpFwRule.from_buffer_copy(data[off : off + sizeof(IpFwRule)])
            rule_len = sizeof(IpFwRule) + hdr.cmd_len * 4
            if off + rule_len > len(data):
                raise ValueError("wrong rule size")
            rule_list.append(RawRule.from_bytes(data[off : off + rule_len]))
            off += align8(rule_len)
        if off != len(data):
            raise ValueError("rule bytes left: off={} len={}".format(off, len(data)))
        return cls(chdr.head.n_type, obj_list=rule_list)

    def __bytes__(self):
        chunks = []
        for rule in self.obj_list:
            rule_bytes = bytes(rule)
            # Zero-pad each rule up to a multiple of 8 bytes.
            chunks.append(rule_bytes + b"\0" * (-len(rule_bytes) % 8))
        payload = b"".join(chunks)
        hdr = IpFwObjCTlv(
            head=IpFwObjTlv(
                n_type=self.obj_type, length=len(payload) + sizeof(IpFwObjCTlv)
            ),
            count=len(self.obj_list),
        )
        return bytes(hdr) + payload
392
393
class BaseIpFwMessage(object):
    """Base class for op3 messages: an ``IpFw3OpHeader`` followed by TLVs.

    Subclasses set ``messages`` (opcode enum members, used to name the
    message) and ``attr_map`` (the TLV decoder map passed to
    ``Tlv.parse_tlvs``).
    """

    messages = []
    # Default decoder map; with no entries every TLV decodes as UnknownTlv.
    attr_map = {}

    def __init__(self, msg_type, obj_list=None):
        """Create a message of opcode *msg_type* (Enum member or raw int)."""
        if isinstance(msg_type, Enum):
            self.obj_type = msg_type.value
            self._enum = msg_type
        else:
            self.obj_type = msg_type
            # obj_name falls back to "msg#N" when this lookup yields None.
            self._enum = enum_from_int(self.messages, self.obj_type)
        self.obj_list = []
        if obj_list:
            self.obj_list.extend(obj_list)

    @staticmethod
    def _raw_type(value):
        """Return ``value.value`` for an Enum member, otherwise *value* itself."""
        # Replaces calls to enum_or_int, which this module never imported.
        return value.value if isinstance(value, Enum) else value

    def add_obj(self, obj):
        """Append a top-level TLV object."""
        self.obj_list.append(obj)

    def get_obj(self, obj_type):
        """Return the first top-level object of *obj_type*, or None."""
        wanted = self._raw_type(obj_type)
        for obj in self.obj_list:
            if obj.obj_type == wanted:
                return obj
        return None

    @staticmethod
    def parse_header(data: bytes):
        """Decode the op3 header; return ``(header, header_size)``."""
        if len(data) < sizeof(IpFw3OpHeader):
            raise ValueError("length less than op3 message header")
        return IpFw3OpHeader.from_buffer_copy(data), sizeof(IpFw3OpHeader)

    def parse_obj_list(self, data: bytes):
        """Append each TLV in *data* as an ``UnknownTlv`` holding its raw bytes.

        Raises ValueError on a truncated header, a declared length smaller
        than the header, or a TLV running past the end of *data*.
        """
        hdr_size = sizeof(IpFwObjTlv)
        off = 0
        while off < len(data):
            if off + hdr_size > len(data):
                raise ValueError("TLV header truncated")
            hdr = IpFwObjTlv.from_buffer_copy(data[off : off + hdr_size])
            if hdr.length < hdr_size:
                # A zero length would never advance `off`.
                raise ValueError("TLV too small")
            if hdr.length + off > len(data):
                raise ValueError("TLV too big")
            # Tlv's constructor only takes a type; UnknownTlv keeps the data.
            self.add_obj(UnknownTlv(hdr.n_type, data[off : off + hdr.length]))
            off += hdr.length

    def is_type(self, msg_type):
        """Return True if this message's opcode equals *msg_type*."""
        # The opcode is stored in obj_type; there is no msg_type attribute.
        return self._raw_type(msg_type) == self.obj_type

    @property
    def obj_name(self):
        """Enum member name, or ``msg#<opcode>`` when unknown."""
        if self._enum is not None:
            return self._enum.name
        return "msg#{}".format(self.obj_type)

    def print_hdr(self, prepend=""):
        print("{}len={}, type={}".format(prepend, len(bytes(self)), self.obj_name))

    @classmethod
    def from_bytes(cls, data):
        """Decode a full op3 message (header plus TLVs) from *data*."""
        try:
            hdr, hdrlen = cls.parse_header(data)
            msg = cls(hdr.opcode)
            msg._orig_data = data
        except ValueError as e:
            print("Failed to parse op3 header: {}".format(e))
            # print_as_bytes used to be called without its description
            # argument, raising TypeError and hiding the original error.
            cls.print_as_bytes(data, "op3 message")
            raise
        msg.obj_list.extend(Tlv.parse_tlvs(data[hdrlen:], msg.attr_map))
        return msg

    def __bytes__(self):
        parts = [bytes(IpFw3OpHeader(opcode=self.obj_type))]
        parts.extend(bytes(obj) for obj in self.obj_list)
        return b"".join(parts)

    def print_obj(self):
        self.print_hdr()
        for obj in self.obj_list:
            obj.print_obj("  ")

    @staticmethod
    def print_as_bytes(data: bytes, descr: str = ""):
        """Print *data* as hex, 16 bytes per line, framed by *descr*."""
        print("===vv {} (len:{:3d}) vv===".format(descr, len(data)))
        step = 16
        for off in range(0, len(data), step):
            print("".join(" {:02X}".format(b) for b in data[off : off + step]))
        print("--------------------")
486
487
# Decoder map for the TLVs carried by rule messages: a container of named
# objects (table, state and external-action names) and the rule list.
rule_attrs = prepare_attrs_map(
    [
        AttrDescr(
            IpFwTlvType.IPFW_TLV_TBLNAME_LIST,
            CTlv,
            [
                AttrDescr(name_type, NTlv)
                for name_type in (
                    IpFwTlvType.IPFW_TLV_TBL_NAME,
                    IpFwTlvType.IPFW_TLV_STATE_NAME,
                    IpFwTlvType.IPFW_TLV_EACTION,
                )
            ],
            # NOTE(review): presumably marks the children as an array; confirm
            # against the AttrDescr definition.
            True,
        ),
        AttrDescr(IpFwTlvType.IPFW_TLV_RULE_LIST, CTlvRule),
    ]
)
503
504
class IpFwXRule(BaseIpFwMessage):
    """IP_FW_XADD message; its TLVs are decoded with ``rule_attrs``."""

    attr_map = rule_attrs
    messages = [Op3CmdType.IP_FW_XADD]
508
509
# Message classes grouped by kind. The names suggest legacy / set-op3 /
# get-op3 sockopts; their consumers are outside this file.
legacy_classes = list()
set3_classes = list()
get3_classes = [IpFwXRule]
513