xref: /freebsd/tests/atf_python/sys/netpfil/ipfw/ioctl.py (revision 9247238cc4b8835892a47701136b0fd073f8d67c)
1#!/usr/bin/env python3
2import os
3import socket
4import struct
5import subprocess
6import sys
7from ctypes import c_byte
8from ctypes import c_char
9from ctypes import c_int
10from ctypes import c_long
11from ctypes import c_uint32
12from ctypes import c_uint8
13from ctypes import c_ulong
14from ctypes import c_ushort
15from ctypes import sizeof
16from ctypes import Structure
17from enum import Enum
18from typing import Any
19from typing import Dict
20from typing import List
21from typing import NamedTuple
22from typing import Optional
23from typing import Union
24
25import pytest
26from atf_python.sys.netpfil.ipfw.insns import BaseInsn
27from atf_python.sys.netpfil.ipfw.insns import insn_attrs
28from atf_python.sys.netpfil.ipfw.ioctl_headers import IpFwTableLookupType
29from atf_python.sys.netpfil.ipfw.ioctl_headers import IpFwTlvType
30from atf_python.sys.netpfil.ipfw.ioctl_headers import Op3CmdType
31from atf_python.sys.netpfil.ipfw.utils import align8
32from atf_python.sys.netpfil.ipfw.utils import AttrDescr
33from atf_python.sys.netpfil.ipfw.utils import enum_from_int
34from atf_python.sys.netpfil.ipfw.utils import prepare_attrs_map
35
36
class IpFw3OpHeader(Structure):
    """ctypes mirror of the op3 message header.

    Four consecutive unsigned 16-bit fields: ``opcode``, ``version`` and two
    reserved words. ``BaseIpFwMessage`` fills only ``opcode`` when encoding.
    """

    _fields_ = [
        (field_name, c_ushort)
        for field_name in ("opcode", "version", "reserved1", "reserved2")
    ]
44
45
class IpFwObjTlv(Structure):
    """ctypes mirror of the generic TLV header that prefixes every op3 object.

    ``length`` is the size in bytes of the whole TLV, this header included;
    the parsers in this module compare it against the buffer they receive.
    """

    _fields_ = [("n_type", c_ushort), ("flags", c_ushort), ("length", c_uint32)]
52
53
class BaseTlv(object):
    """Base class for op3 TLV objects.

    Subclasses describe a concrete TLV layout. They override ``_validate`` and
    ``_parse`` to decode raw bytes, and ``__bytes__`` to encode the object back.
    Child objects, if any, are kept in ``obj_list``.
    """

    # Enum used to turn a raw TLV type number into a symbolic name.
    obj_enum_class = IpFwTlvType

    def __init__(self, obj_type):
        """Create a TLV of type ``obj_type`` (an Enum member or a raw int)."""
        if isinstance(obj_type, Enum):
            self.obj_type = obj_type.value
            self._enum = obj_type
        else:
            self.obj_type = obj_type
            # obj_name handles a None result here (unknown type number).
            self._enum = enum_from_int(self.obj_enum_class, obj_type)
        self.obj_list = []

    def add_obj(self, obj):
        """Append a child object to this TLV."""
        self.obj_list.append(obj)

    @property
    def len(self):
        """Size in bytes of the encoded TLV, computed via ``bytes(self)``."""
        return len(bytes(self))

    @property
    def obj_name(self):
        """Symbolic TLV type name, or ``tlv#<N>`` when the type is unknown."""
        if self._enum is not None:
            return self._enum.name
        else:
            return "tlv#{}".format(self.obj_type)

    def print_hdr(self, prepend=""):
        """Print a one-line summary: length, type name, raw type and value."""
        print(
            "{}len={} type={}({}){}".format(
                prepend, self.len, self.obj_name, self.obj_type, self._print_obj_value()
            )
        )

    def print_obj(self, prepend=""):
        """Print this TLV, then its children indented by two more spaces."""
        self.print_hdr(prepend)
        prepend = "  " + prepend
        for obj in self.obj_list:
            obj.print_obj(prepend)

    @classmethod
    def _validate(cls, data):
        """Check that ``data`` holds a TLV header whose length field matches.

        Raises:
            ValueError: if ``data`` is shorter than an IpFwObjTlv header, or if
                its size differs from the header's ``length`` field.
        """
        if len(data) < sizeof(IpFwObjTlv):
            raise ValueError("TLV too short")
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        if len(data) != hdr.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        """Build an instance from the header type of validated ``data``."""
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        return cls(hdr.n_type)

    @classmethod
    def from_bytes(cls, data, attr_map=None):
        """Validate ``data`` and decode it into an instance of ``cls``.

        ``attr_map`` is passed unchanged to ``_parse``.

        Raises:
            ValueError: if ``_validate`` rejects ``data``.
        """
        cls._validate(data)
        obj = cls._parse(data, attr_map)
        return obj

    def __bytes__(self):
        """Encode the TLV. Subclasses must implement this."""
        raise NotImplementedError()

    def _print_obj_value(self):
        """Return the TLV payload (bytes after the header) as hex text.

        Uses the raw ``_data`` buffer when the object keeps one. Otherwise it
        falls back to ``bytes(self)``, so any subclass implementing
        ``__bytes__`` can be printed, not only UnknownTlv.
        """
        data = getattr(self, "_data", None)
        if data is None:
            data = bytes(self)
        return " " + " ".join(
            ["x{:02X}".format(b) for b in data[sizeof(IpFwObjTlv) :]]
        )

    def as_hexdump(self):
        """Return the encoded TLV as space-separated ``xHH`` bytes."""
        return " ".join(["x{:02X}".format(b) for b in bytes(self)])
122
123
class UnknownTlv(BaseTlv):
    """TLV whose type has no descriptor in the attribute map.

    The object keeps the complete encoded TLV (header included) verbatim.
    """

    def __init__(self, obj_type, data):
        super().__init__(obj_type)
        # The full TLV buffer, header included (see _parse below).
        self._data = data

    @classmethod
    def _validate(cls, data):
        """Check that ``data`` holds at least a generic TLV header of matching size.

        An unknown TLV only needs the generic IpFwObjTlv header. Requiring
        sizeof(IpFwObjNTlv) (the much larger named-object TLV) here would
        reject every unknown TLV shorter than that.

        Raises:
            ValueError: if ``data`` is shorter than the header, or if its size
                differs from the header's ``length`` field.
        """
        if len(data) < sizeof(IpFwObjTlv):
            raise ValueError("TLV size is too short")
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        if len(data) != hdr.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        """Wrap validated ``data`` without decoding the payload."""
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        self = cls(hdr.n_type, data)
        return self

    def __bytes__(self):
        """Return the stored TLV bytes unchanged.

        The buffer is coerced to ``bytes`` because ``__bytes__`` must return a
        bytes object, and ``_data`` may be another buffer type such as bytearray.
        """
        return bytes(self._data)
145
146
class Tlv(BaseTlv):
    """TLV base class that can decode a run of child TLVs."""

    @staticmethod
    def parse_tlvs(data, attr_map):
        """Decode consecutive TLVs from ``data``.

        Each TLV's ``n_type`` is looked up in ``attr_map``. When a descriptor is
        found, its ``["ad"].cls`` decodes the TLV and ``["child"]`` (default
        ``{}``) is passed down as the child map. Types with no descriptor are
        decoded as UnknownTlv. A None ``attr_map`` is treated as an empty map.
        Trailing bytes too short to hold a TLV header are ignored.

        Returns:
            List of decoded TLV objects, in input order.

        Raises:
            ValueError: if a TLV's length is smaller than its own header or runs
                past the end of ``data``, or if the chosen class rejects it.
        """
        if attr_map is None:
            attr_map = {}
        hdr_size = sizeof(IpFwObjTlv)
        off = 0
        ret = []
        while off + hdr_size <= len(data):
            hdr = IpFwObjTlv.from_buffer_copy(data[off : off + hdr_size])
            if hdr.length < hdr_size:
                # A zero/undersized length would never advance ``off``.
                raise ValueError("TLV size is too small")
            if off + hdr.length > len(data):
                raise ValueError("TLV size do not match")
            obj_data = data[off : off + hdr.length]
            obj_descr = attr_map.get(hdr.n_type, None)
            if obj_descr is None:
                cls = UnknownTlv
                child_map = {}
            else:
                cls = obj_descr["ad"].cls
                child_map = obj_descr.get("child", {})
            ret.append(cls.from_bytes(obj_data, child_map))
            off += hdr.length
        return ret
172
173
class IpFwObjNTlv(Structure):
    """ctypes mirror of a named-object TLV.

    A generic TLV header is followed by an index, a set number, an object
    type, a spare word, and a fixed 64-byte name buffer.
    """

    _fields_ = [("head", IpFwObjTlv)]
    _fields_ += [("idx", c_ushort), ("n_set", c_uint8), ("n_type", c_uint8)]
    _fields_ += [("spare", c_uint32), ("name", c_char * 64)]
183
184
class NTlv(Tlv):
    """Named-object TLV, encoded with the IpFwObjNTlv layout."""

    def __init__(self, obj_type, idx=0, n_set=0, n_type=0, name=None):
        """Create a named TLV.

        Args:
            obj_type: TLV type (Enum member or raw int).
            idx: object index stored in the ``idx`` field.
            n_set: set number.
            n_type: object type byte (distinct from the TLV type).
            name: object name. None is encoded as an empty name.
        """
        super().__init__(obj_type)
        self.n_idx = idx
        self.n_set = n_set
        self.n_type = n_type
        self.n_name = name

    @classmethod
    def _validate(cls, data):
        """Require exactly one IpFwObjNTlv whose header length matches.

        Raises:
            ValueError: on a size mismatch.
        """
        if len(data) != sizeof(IpFwObjNTlv):
            raise ValueError("TLV size is not correct")
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        if len(data) != hdr.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        """Decode validated ``data``. The name is decoded as UTF-8."""
        hdr = IpFwObjNTlv.from_buffer_copy(data[: sizeof(IpFwObjNTlv)])
        name = hdr.name.decode("utf-8")
        self = cls(hdr.head.n_type, hdr.idx, hdr.n_set, hdr.n_type, name)
        return self

    def __bytes__(self):
        """Encode as IpFwObjNTlv. The name is NUL-padded or truncated to 64 bytes."""
        # name defaults to None in __init__; encode that as an empty name
        # instead of failing on None.encode().
        name = self.n_name if self.n_name is not None else ""
        name_bytes = name.encode("utf-8")
        if len(name_bytes) < 64:
            name_bytes += b"\0" * (64 - len(name_bytes))
        hdr = IpFwObjNTlv(
            head=IpFwObjTlv(n_type=self.obj_type, length=sizeof(IpFwObjNTlv)),
            idx=self.n_idx,
            n_set=self.n_set,
            n_type=self.n_type,
            name=name_bytes[:64],
        )
        return bytes(hdr)

    def _print_obj_value(self):
        """Return the decoded fields as text for print_hdr()."""
        return " " + "type={} set={} idx={} name={}".format(
            self.n_type, self.n_set, self.n_idx, self.n_name
        )
225
226
class IpFwObjCTlv(Structure):
    """ctypes mirror of a container TLV header.

    A generic TLV header is followed by the number of contained objects,
    their per-object size, a version byte and a flags byte.
    """

    _fields_ = [("head", IpFwObjTlv), ("count", c_uint32)]
    _fields_ += [("objsize", c_ushort), ("version", c_uint8), ("flags", c_uint8)]
235
236
class CTlv(Tlv):
    """Container TLV (IpFwObjCTlv header) holding a list of child TLVs."""

    def __init__(self, obj_type, obj_list=None):
        """Create a container of type ``obj_type``, optionally pre-filled.

        ``obj_list`` is copied, so the caller's list is never shared.
        """
        super().__init__(obj_type)
        if obj_list:
            self.obj_list.extend(obj_list)

    @classmethod
    def _validate(cls, data):
        """Require a full container header whose length matches ``data``.

        Raises:
            ValueError: if ``data`` is too short or the length mismatches.
        """
        if len(data) < sizeof(IpFwObjCTlv):
            raise ValueError("TLV too short")
        hdr = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)])
        if len(data) != hdr.head.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        """Decode the children with ``attr_map`` and check them against ``count``.

        Raises:
            ValueError: if the number of decoded children differs from ``count``.
        """
        hdr = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)])
        tlv_list = cls.parse_tlvs(data[sizeof(IpFwObjCTlv) :], attr_map)
        if len(tlv_list) != hdr.count:
            raise ValueError("wrong number of objects")
        self = cls(hdr.head.n_type, obj_list=tlv_list)
        return self

    def __bytes__(self):
        """Encode the header followed by every child.

        The header's ``length`` covers the header plus all encoded children,
        which is what ``_validate`` checks on decode. ``objsize`` is the size
        of the first child (0 when empty).
        """
        parts = [bytes(obj) for obj in self.obj_list]
        ret = b"".join(parts)
        length = len(ret) + sizeof(IpFwObjCTlv)
        objsize = len(parts[0]) if parts else 0
        hdr = IpFwObjCTlv(
            head=IpFwObjTlv(n_type=self.obj_type, length=length),
            count=len(self.obj_list),
            objsize=objsize,
        )
        return bytes(hdr) + ret

    def _print_obj_value(self):
        """Containers print no inline value; children are printed separately."""
        return ""
278
279
class IpFwRule(Structure):
    """ctypes mirror of the fixed rule header that precedes rule instructions.

    ``act_ofs`` and ``cmd_len`` are counted in 4-byte units: RawRule multiplies
    ``cmd_len`` by 4 to get the instruction byte count.
    """

    _fields_ = [
        ("act_ofs", c_ushort),  # offset of the first action instruction
        ("cmd_len", c_ushort),  # total instruction length
        ("spare", c_ushort),
        ("n_set", c_uint8),  # rule set number
        ("flags", c_uint8),
        ("rulenum", c_uint32),
        ("n_id", c_uint32),
    ]
290
291
class RawRule(Tlv):
    """A single rule: an IpFwRule header followed by encoded instructions."""

    def __init__(self, obj_type=0, n_set=0, rulenum=0, obj_list=None):
        """Create a rule in set ``n_set`` with number ``rulenum``.

        ``obj_list`` holds the instruction objects and is copied.
        """
        super().__init__(obj_type)
        self.n_set = n_set
        self.rulenum = rulenum
        if obj_list:
            self.obj_list.extend(obj_list)

    @classmethod
    def _validate(cls, data):
        """Require a rule header whose ``cmd_len`` (4-byte units) matches ``data``.

        Raises:
            ValueError: if ``data`` is too short or ``cmd_len`` mismatches.
        """
        min_size = sizeof(IpFwRule)
        if len(data) < min_size:
            raise ValueError("rule TLV too short")
        rule = IpFwRule.from_buffer_copy(data[:min_size])
        if len(data) != min_size + rule.cmd_len * 4:
            raise ValueError("rule TLV cmd_len incorrect")

    @classmethod
    def _parse(cls, data, attr_map):
        """Decode the header and instructions. Instructions use ``insn_attrs``."""
        hdr = IpFwRule.from_buffer_copy(data[: sizeof(IpFwRule)])
        self = cls(
            n_set=hdr.n_set,
            rulenum=hdr.rulenum,
            obj_list=BaseInsn.parse_insns(data[sizeof(IpFwRule) :], insn_attrs),
        )
        return self

    def __bytes__(self):
        """Encode the rule header followed by all instructions.

        ``act_ofs`` is the offset, in 4-byte units, of the first instruction
        whose ``is_action`` is true (0 if there is none). None is used as the
        "not found yet" marker, so a first action at offset 0 is not replaced
        by a later action.
        """
        act_ofs = None
        cmd_len = 0
        parts = []
        for obj in self.obj_list:
            if act_ofs is None and obj.is_action:
                act_ofs = cmd_len
            obj_bytes = bytes(obj)
            cmd_len += len(obj_bytes) // 4
            parts.append(obj_bytes)

        hdr = IpFwRule(
            act_ofs=act_ofs if act_ofs is not None else 0,
            cmd_len=cmd_len,
            n_set=self.n_set,
            rulenum=self.rulenum,
        )
        return bytes(hdr) + b"".join(parts)

    @property
    def obj_name(self):
        """Rules are named by number: ``rule#<rulenum>``."""
        return "rule#{}".format(self.rulenum)

    def _print_obj_value(self):
        """Return the set number and instruction length (4-byte units) as text."""
        cmd_len = sum([len(bytes(obj)) for obj in self.obj_list]) // 4
        return " set={} cmd_len={}".format(self.n_set, cmd_len)
345
346
class CTlvRule(CTlv):
    """Container TLV whose payload is a sequence of rules, each 8-byte aligned."""

    def __init__(self, obj_type=IpFwTlvType.IPFW_TLV_RULE_LIST, obj_list=[]):
        super().__init__(obj_type, obj_list)

    @classmethod
    def _parse(cls, data, attr_map):
        """Decode the container header and every RawRule that follows it.

        Each rule spans its IpFwRule header plus ``cmd_len * 4`` bytes. The
        next rule starts at that size rounded with ``align8``.

        Raises:
            ValueError: if a rule overruns ``data`` or bytes remain after the
                last rule.
        """
        container = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)])
        rule_hdr_size = sizeof(IpFwRule)
        total = len(data)
        rules = []
        pos = sizeof(IpFwObjCTlv)
        while pos + rule_hdr_size <= total:
            rule_hdr = IpFwRule.from_buffer_copy(data[pos : pos + rule_hdr_size])
            size = rule_hdr_size + 4 * rule_hdr.cmd_len
            end = pos + size
            if end > total:
                raise ValueError("wrong rule size")
            rules.append(RawRule.from_bytes(data[pos:end]))
            pos += align8(size)
        if pos != total:
            raise ValueError("rule bytes left: off={} len={}".format(pos, total))
        return cls(container.head.n_type, obj_list=rules)

    # XXX: _validate (currently inherited from CTlv)

    def __bytes__(self):
        """Encode the header plus each rule, zero-padded to a multiple of 8 bytes."""
        chunks = []
        for rule in self.obj_list:
            encoded = bytes(rule)
            chunks.append(encoded + b"\0" * (-len(encoded) % 8))
        payload = b"".join(chunks)
        container = IpFwObjCTlv(
            head=IpFwObjTlv(
                n_type=self.obj_type, length=sizeof(IpFwObjCTlv) + len(payload)
            ),
            count=len(self.obj_list),
        )
        return bytes(container) + payload
386
387
class BaseIpFwMessage(object):
    """Base class for op3 messages: an IpFw3OpHeader followed by TLVs.

    Subclasses set ``messages`` (the opcode enum members they represent) and
    ``attr_map`` (the descriptors used to decode the TLVs after the header).
    """

    # Opcode enum members used to map a raw opcode to a symbolic name.
    messages = []
    # TLV type -> descriptor map used by from_bytes(). Subclasses override it.
    # With this empty default, Tlv.parse_tlvs() decodes every TLV as UnknownTlv.
    attr_map = {}

    def __init__(self, msg_type, obj_list=None):
        """Create a message of ``msg_type`` (Enum member or raw opcode).

        ``obj_list`` is copied into the message's own list.
        """
        if isinstance(msg_type, Enum):
            self.obj_type = msg_type.value
            self._enum = msg_type
        else:
            self.obj_type = msg_type
            self._enum = enum_from_int(self.messages, self.obj_type)
        self.obj_list = []
        if obj_list:
            self.obj_list.extend(obj_list)

    @staticmethod
    def _raw_type(val):
        """Return ``val.value`` for Enum members, otherwise ``val`` unchanged."""
        return val.value if isinstance(val, Enum) else val

    def add_obj(self, obj):
        """Append a top-level object to the message."""
        self.obj_list.append(obj)

    def get_obj(self, obj_type):
        """Return the first top-level object of ``obj_type`` (Enum or int), or None."""
        obj_type_raw = self._raw_type(obj_type)
        for obj in self.obj_list:
            if obj.obj_type == obj_type_raw:
                return obj
        return None

    @staticmethod
    def parse_header(data: bytes):
        """Decode the op3 header.

        Returns:
            Tuple ``(header, header_size)``.

        Raises:
            ValueError: if ``data`` is shorter than the header.
        """
        if len(data) < sizeof(IpFw3OpHeader):
            raise ValueError("length less than op3 message header")
        return IpFw3OpHeader.from_buffer_copy(data), sizeof(IpFw3OpHeader)

    def parse_obj_list(self, data: bytes):
        """Split ``data`` into raw TLVs and append each one as an UnknownTlv.

        Raises:
            ValueError: on a truncated TLV header, a TLV length smaller than
                its header, or a TLV extending past the end of ``data``.
        """
        hdr_size = sizeof(IpFwObjTlv)
        off = 0
        while off < len(data):
            if off + hdr_size > len(data):
                raise ValueError("TLV header truncated")
            hdr = IpFwObjTlv.from_buffer_copy(data[off : off + hdr_size])
            if hdr.length < hdr_size:
                # A zero/undersized length would never advance ``off``.
                raise ValueError("TLV too small")
            if hdr.length + off > len(data):
                raise ValueError("TLV too big")
            self.add_obj(UnknownTlv(hdr.n_type, data[off : off + hdr.length]))
            off += hdr.length

    def is_type(self, msg_type):
        """Return True if this message's opcode equals ``msg_type`` (Enum or int)."""
        return self._raw_type(msg_type) == self.obj_type

    @property
    def obj_name(self):
        """Symbolic opcode name, or ``msg#<N>`` when unknown."""
        if self._enum is not None:
            return self._enum.name
        else:
            return "msg#{}".format(self.obj_type)

    def print_hdr(self, prepend=""):
        """Print the encoded length and the opcode name."""
        print("{}len={}, type={}".format(prepend, len(bytes(self)), self.obj_name))

    @classmethod
    def from_bytes(cls, data):
        """Decode a full op3 message. TLVs are decoded with ``cls.attr_map``.

        Raises:
            ValueError: if the header or any TLV is malformed. A header
                failure also hex-dumps ``data`` before re-raising.
        """
        try:
            hdr, hdrlen = cls.parse_header(data)
            self = cls(hdr.opcode)
            self._orig_data = data
        except ValueError as e:
            print("Failed to parse op3 header: {}".format(e))
            cls.print_as_bytes(data, "op3 message")
            raise
        tlv_list = Tlv.parse_tlvs(data[hdrlen:], self.attr_map)
        self.obj_list.extend(tlv_list)
        return self

    def __bytes__(self):
        """Encode the op3 header (opcode only) followed by every object."""
        ret = bytes(IpFw3OpHeader(opcode=self.obj_type))
        for obj in self.obj_list:
            ret += bytes(obj)
        return ret

    def print_obj(self):
        """Print the header line, then each top-level object indented."""
        self.print_hdr()
        for obj in self.obj_list:
            obj.print_obj("  ")

    @staticmethod
    def print_as_bytes(data: bytes, descr: str = ""):
        """Hex-dump ``data``, 16 bytes per line, framed by a ``descr`` banner."""
        print("===vv {} (len:{:3d}) vv===".format(descr, len(data)))
        step = 16
        for off in range(0, len(data), step):
            print("".join(" {:02X}".format(b) for b in data[off : off + step]))
        print("--------------------")
480
481
# Descriptor map for rule-carrying op3 messages. It covers a container of
# named objects (table / state names, each an NTlv) and the rule list
# (CTlvRule). The trailing ``True`` on the name-list entry is passed straight
# to AttrDescr; presumably it marks the entry as holding repeated children,
# but confirm against AttrDescr in utils.
rule_attrs = prepare_attrs_map(
    [
        AttrDescr(
            IpFwTlvType.IPFW_TLV_TBLNAME_LIST,
            CTlv,
            [
                AttrDescr(name_type, NTlv)
                for name_type in (
                    IpFwTlvType.IPFW_TLV_TBL_NAME,
                    IpFwTlvType.IPFW_TLV_STATE_NAME,
                )
            ],
            True,
        ),
        AttrDescr(IpFwTlvType.IPFW_TLV_RULE_LIST, CTlvRule),
    ]
)
496
497
class IpFwXRule(BaseIpFwMessage):
    """IP_FW_XADD message. Its TLVs are decoded with ``rule_attrs``."""

    # TLV descriptors: the table/state name list and the rule list.
    attr_map = rule_attrs
    # Opcode(s) this message class represents.
    messages = [Op3CmdType.IP_FW_XADD]
501
502
# Message classes grouped by request kind. Presumably consumed by the ioctl
# dispatch/parsing helpers elsewhere in atf_python; TODO confirm.
legacy_classes = list()
set3_classes = list()
get3_classes = [IpFwXRule]
506