#!/usr/bin/env python3
"""ctypes-based (de)serialization of ipfw op3 messages and their TLVs.

The module defines the raw ``ctypes`` structures used on the wire and a set
of Python wrapper classes (``BaseTlv`` and friends, ``BaseIpFwMessage``) that
convert between ``bytes`` and object trees.
"""
import os
import socket
import struct
import subprocess
import sys
from ctypes import c_byte
from ctypes import c_char
from ctypes import c_int
from ctypes import c_long
from ctypes import c_uint32
from ctypes import c_uint8
from ctypes import c_ulong
from ctypes import c_ushort
from ctypes import sizeof
from ctypes import Structure
from enum import Enum
from typing import Any
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import Union

import pytest
from atf_python.sys.netpfil.ipfw.insns import BaseInsn
from atf_python.sys.netpfil.ipfw.insns import insn_attrs
from atf_python.sys.netpfil.ipfw.ioctl_headers import IpFwTableLookupType
from atf_python.sys.netpfil.ipfw.ioctl_headers import IpFwTlvType
from atf_python.sys.netpfil.ipfw.ioctl_headers import Op3CmdType
from atf_python.sys.netpfil.ipfw.utils import align8
from atf_python.sys.netpfil.ipfw.utils import AttrDescr
from atf_python.sys.netpfil.ipfw.utils import enum_from_int
from atf_python.sys.netpfil.ipfw.utils import prepare_attrs_map


def _enum_or_int(val):
    """Return ``val.value`` for an ``Enum`` member, and *val* itself otherwise."""
    return val.value if isinstance(val, Enum) else val


class IpFw3OpHeader(Structure):
    """Raw layout of the header that starts every op3 message."""

    _fields_ = [
        ("opcode", c_ushort),
        ("version", c_ushort),
        ("reserved1", c_ushort),
        ("reserved2", c_ushort),
    ]


class IpFwObjTlv(Structure):
    """Raw layout of the generic TLV header (type, flags, total length)."""

    _fields_ = [
        ("n_type", c_ushort),
        ("flags", c_ushort),
        ("length", c_uint32),
    ]


class BaseTlv(object):
    """Common behaviour of all TLV wrappers.

    Subclasses provide ``__bytes__`` and usually override ``_validate`` and
    ``_parse``; ``from_bytes`` ties the two together.
    """

    obj_enum_class = IpFwTlvType

    def __init__(self, obj_type):
        """Store the numeric type and, when resolvable, its enum member.

        Args:
            obj_type: either an ``Enum`` member or a raw integer type.
        """
        if isinstance(obj_type, Enum):
            self.obj_type = obj_type.value
            self._enum = obj_type
        else:
            self.obj_type = obj_type
            self._enum = enum_from_int(self.obj_enum_class, obj_type)
        self.obj_list = []

    def add_obj(self, obj):
        """Append a child object to this TLV."""
        self.obj_list.append(obj)

    @property
    def len(self):
        """Serialized size of this TLV in bytes."""
        return len(bytes(self))

    @property
    def obj_name(self):
        """Enum member name, or ``tlv#<type>`` when no member was resolved."""
        if self._enum is not None:
            return self._enum.name
        return "tlv#{}".format(self.obj_type)

    def print_hdr(self, prepend=""):
        """Print a one-line summary of this TLV, prefixed by *prepend*."""
        print(
            "{}len={} type={}({}){}".format(
                prepend, self.len, self.obj_name, self.obj_type, self._print_obj_value()
            )
        )

    def print_obj(self, prepend=""):
        """Print this TLV and, with a deeper prefix, all of its children."""
        self.print_hdr(prepend)
        prepend = " " + prepend
        for obj in self.obj_list:
            obj.print_obj(prepend)

    @classmethod
    def _validate(cls, data):
        """Check that *data* holds exactly one TLV.

        Raises:
            ValueError: if *data* is shorter than the TLV header or its size
                differs from the length recorded in the header.
        """
        if len(data) < sizeof(IpFwObjTlv):
            raise ValueError("TLV too short")
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        if len(data) != hdr.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        """Build an instance from already validated *data*."""
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        return cls(hdr.n_type)

    @classmethod
    def from_bytes(cls, data, attr_map=None):
        """Validate *data* and parse it into an instance of *cls*."""
        cls._validate(data)
        return cls._parse(data, attr_map)

    def __bytes__(self):
        raise NotImplementedError()

    def _print_obj_value(self):
        """Return the payload (everything after the TLV header) as hex."""
        # Use the serialized form rather than a subclass-specific attribute so
        # that this default works for every subclass implementing __bytes__.
        payload = bytes(self)[sizeof(IpFwObjTlv) :]
        return " " + " ".join(["x{:02X}".format(b) for b in payload])

    def as_hexdump(self):
        """Return the whole serialized TLV as space-separated hex bytes."""
        return " ".join(["x{:02X}".format(b) for b in bytes(self)])


class UnknownTlv(BaseTlv):
    """Opaque TLV that keeps its raw bytes untouched."""

    def __init__(self, obj_type, data):
        super().__init__(obj_type)
        self._data = data

    @classmethod
    def _validate(cls, data):
        """Check that *data* holds exactly one TLV.

        Raises:
            ValueError: if *data* is shorter than the generic TLV header or
                its size differs from the length recorded in the header.
        """
        # An unknown TLV only has to carry the generic header; requiring the
        # (much larger) named-TLV size would reject valid short TLVs.
        if len(data) < sizeof(IpFwObjTlv):
            raise ValueError("TLV size is too short")
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        if len(data) != hdr.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        return cls(hdr.n_type, data)

    def __bytes__(self):
        return self._data


class Tlv(BaseTlv):
    """TLV base class that knows how to parse a sequence of child TLVs."""

    @staticmethod
    def parse_tlvs(data, attr_map):
        """Parse consecutive TLVs from *data*.

        Args:
            data: buffer holding zero or more back-to-back TLVs.
            attr_map: mapping from numeric TLV type to a descriptor dict with
                an ``"ad"`` entry (whose ``cls`` attribute is the wrapper
                class) and an optional ``"child"`` map. ``None`` is treated as
                an empty map, so every TLV is parsed as ``UnknownTlv``.

        Returns:
            List of parsed TLV objects, in buffer order.

        Raises:
            ValueError: if a TLV claims a length smaller than its header or
                larger than the remaining data.
        """
        attr_map = attr_map or {}
        hdr_len = sizeof(IpFwObjTlv)
        off = 0
        ret = []
        while off + hdr_len <= len(data):
            hdr = IpFwObjTlv.from_buffer_copy(data[off : off + hdr_len])
            if hdr.length < hdr_len:
                # Also guarantees forward progress of the loop.
                raise ValueError("TLV too short")
            if off + hdr.length > len(data):
                raise ValueError("TLV size do not match")
            obj_data = data[off : off + hdr.length]
            obj_descr = attr_map.get(hdr.n_type, None)
            if obj_descr is None:
                cls = UnknownTlv
                child_map = {}
            else:
                cls = obj_descr["ad"].cls
                child_map = obj_descr.get("child", {})
            ret.append(cls.from_bytes(obj_data, child_map))
            off += hdr.length
        return ret


class IpFwObjNTlv(Structure):
    """Raw layout of a named TLV: generic header, index, set, type, name."""

    _fields_ = [
        ("head", IpFwObjTlv),
        ("idx", c_ushort),
        ("n_set", c_uint8),
        ("n_type", c_uint8),
        ("spare", c_uint32),
        ("name", c_char * 64),
    ]


class NTlv(Tlv):
    """Wrapper around ``IpFwObjNTlv``."""

    def __init__(self, obj_type, idx=0, n_set=0, n_type=0, name=None):
        super().__init__(obj_type)
        self.n_idx = idx
        self.n_set = n_set
        self.n_type = n_type
        self.n_name = name

    @classmethod
    def _validate(cls, data):
        """Require *data* to be exactly one ``IpFwObjNTlv``.

        Raises:
            ValueError: on any size mismatch.
        """
        if len(data) != sizeof(IpFwObjNTlv):
            raise ValueError("TLV size is not correct")
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        if len(data) != hdr.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        hdr = IpFwObjNTlv.from_buffer_copy(data[: sizeof(IpFwObjNTlv)])
        name = hdr.name.decode("utf-8")
        return cls(hdr.head.n_type, hdr.idx, hdr.n_set, hdr.n_type, name)

    def __bytes__(self):
        # ``name`` defaults to None in __init__; serialize that as an empty
        # name instead of failing on ``None.encode``.
        name_bytes = (self.n_name or "").encode("utf-8")
        if len(name_bytes) < 64:
            name_bytes += b"\0" * (64 - len(name_bytes))
        hdr = IpFwObjNTlv(
            head=IpFwObjTlv(n_type=self.obj_type, length=sizeof(IpFwObjNTlv)),
            idx=self.n_idx,
            n_set=self.n_set,
            n_type=self.n_type,
            name=name_bytes[:64],
        )
        return bytes(hdr)

    def _print_obj_value(self):
        return " " + "type={} set={} idx={} name={}".format(
            self.n_type, self.n_set, self.n_idx, self.n_name
        )


class IpFwObjCTlv(Structure):
    """Raw layout of a container TLV header."""

    _fields_ = [
        ("head", IpFwObjTlv),
        ("count", c_uint32),
        ("objsize", c_ushort),
        ("version", c_uint8),
        ("flags", c_uint8),
    ]


class CTlv(Tlv):
    """Container TLV: an ``IpFwObjCTlv`` header followed by child TLVs."""

    def __init__(self, obj_type, obj_list=None):
        super().__init__(obj_type)
        if obj_list:
            self.obj_list.extend(obj_list)

    @classmethod
    def _validate(cls, data):
        """Check the container header size and the recorded total length.

        Raises:
            ValueError: if *data* is shorter than the container header or its
                size differs from the length recorded in the header.
        """
        if len(data) < sizeof(IpFwObjCTlv):
            raise ValueError("TLV too short")
        hdr = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)])
        if len(data) != hdr.head.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        """Parse the children and verify their number against ``count``.

        Raises:
            ValueError: if the number of parsed children differs from the
                header's ``count`` field.
        """
        hdr = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)])
        tlv_list = cls.parse_tlvs(data[sizeof(IpFwObjCTlv) :], attr_map)
        if len(tlv_list) != hdr.count:
            raise ValueError("wrong number of objects")
        return cls(hdr.head.n_type, obj_list=tlv_list)

    def __bytes__(self):
        children = [bytes(obj) for obj in self.obj_list]
        payload = b"".join(children)
        # The header length covers the container header plus all children,
        # which is exactly what _validate() checks on the way back in.
        length = len(payload) + sizeof(IpFwObjCTlv)
        objsize = len(children[0]) if children else 0
        hdr = IpFwObjCTlv(
            head=IpFwObjTlv(n_type=self.obj_type, length=length),
            count=len(self.obj_list),
            objsize=objsize,
        )
        return bytes(hdr) + payload

    def _print_obj_value(self):
        return ""


class IpFwRule(Structure):
    """Raw layout of the fixed-size rule header preceding the instructions."""

    _fields_ = [
        ("act_ofs", c_ushort),
        ("cmd_len", c_ushort),
        ("spare", c_ushort),
        ("n_set", c_uint8),
        ("flags", c_uint8),
        ("rulenum", c_uint32),
        ("n_id", c_uint32),
    ]


class RawRule(Tlv):
    """A single rule: an ``IpFwRule`` header followed by instructions.

    ``obj_list`` holds the instruction objects; ``cmd_len`` and ``act_ofs``
    in the serialized header are counted in 4-byte words.
    """

    def __init__(self, obj_type=0, n_set=0, rulenum=0, obj_list=None):
        super().__init__(obj_type)
        self.n_set = n_set
        self.rulenum = rulenum
        if obj_list:
            self.obj_list.extend(obj_list)

    @classmethod
    def _validate(cls, data):
        """Check that *data* is a rule header plus ``cmd_len`` 4-byte words.

        Raises:
            ValueError: if *data* is shorter than the header or its size does
                not match ``cmd_len``.
        """
        min_size = sizeof(IpFwRule)
        if len(data) < min_size:
            raise ValueError("rule TLV too short")
        rule = IpFwRule.from_buffer_copy(data[:min_size])
        if len(data) != min_size + rule.cmd_len * 4:
            raise ValueError("rule TLV cmd_len incorrect")

    @classmethod
    def _parse(cls, data, attr_map):
        hdr = IpFwRule.from_buffer_copy(data[: sizeof(IpFwRule)])
        return cls(
            n_set=hdr.n_set,
            rulenum=hdr.rulenum,
            obj_list=BaseInsn.parse_insns(data[sizeof(IpFwRule) :], insn_attrs),
        )

    def __bytes__(self):
        # None means "no action instruction seen yet". Using 0 as the marker
        # would let a later action overwrite a legitimate offset of 0.
        act_ofs = None
        cmd_len = 0
        chunks = []
        for obj in self.obj_list:
            if act_ofs is None and obj.is_action:
                act_ofs = cmd_len
            obj_bytes = bytes(obj)
            cmd_len += len(obj_bytes) // 4
            chunks.append(obj_bytes)

        hdr = IpFwRule(
            act_ofs=act_ofs or 0,
            cmd_len=cmd_len,
            n_set=self.n_set,
            rulenum=self.rulenum,
        )
        return bytes(hdr) + b"".join(chunks)

    @property
    def obj_name(self):
        return "rule#{}".format(self.rulenum)

    def _print_obj_value(self):
        cmd_len = sum([len(bytes(obj)) for obj in self.obj_list]) // 4
        return " set={} cmd_len={}".format(self.n_set, cmd_len)


class CTlvRule(CTlv):
    """Container TLV whose payload is a list of 8-byte aligned rules."""

    def __init__(self, obj_type=IpFwTlvType.IPFW_TLV_RULE_LIST, obj_list=None):
        super().__init__(obj_type, obj_list)

    @classmethod
    def _parse(cls, data, attr_map):
        """Parse back-to-back rules following the container header.

        Raises:
            ValueError: if a rule extends past the buffer or bytes are left
                over after the last rule.
        """
        chdr = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)])
        rule_list = []
        off = sizeof(IpFwObjCTlv)
        while off + sizeof(IpFwRule) <= len(data):
            hdr = IpFwRule.from_buffer_copy(data[off : off + sizeof(IpFwRule)])
            rule_len = sizeof(IpFwRule) + hdr.cmd_len * 4
            if off + rule_len > len(data):
                raise ValueError("wrong rule size")
            rule_list.append(RawRule.from_bytes(data[off : off + rule_len]))
            # Each rule is padded to an 8-byte boundary (see __bytes__).
            off += align8(rule_len)
        if off != len(data):
            raise ValueError("rule bytes left: off={} len={}".format(off, len(data)))
        return cls(chdr.head.n_type, obj_list=rule_list)

    # XXX: _validate

    def __bytes__(self):
        chunks = []
        for rule in self.obj_list:
            rule_bytes = bytes(rule)
            # Zero-pad every rule up to the next multiple of 8 bytes.
            chunks.append(rule_bytes + b"\0" * (-len(rule_bytes) % 8))
        payload = b"".join(chunks)
        hdr = IpFwObjCTlv(
            head=IpFwObjTlv(
                n_type=self.obj_type, length=len(payload) + sizeof(IpFwObjCTlv)
            ),
            count=len(self.obj_list),
        )
        return bytes(hdr) + payload


class BaseIpFwMessage(object):
    """An op3 message: ``IpFw3OpHeader`` followed by a list of TLVs.

    Subclasses set ``messages`` (the opcodes they represent) and ``attr_map``
    (the TLV type map used by ``from_bytes``).
    """

    messages = []

    def __init__(self, msg_type, obj_list=None):
        if isinstance(msg_type, Enum):
            self.obj_type = msg_type.value
            self._enum = msg_type
        else:
            self.obj_type = msg_type
            self._enum = enum_from_int(self.messages, self.obj_type)
        self.obj_list = []
        if obj_list:
            self.obj_list.extend(obj_list)

    def add_obj(self, obj):
        """Append a TLV to the message."""
        self.obj_list.append(obj)

    def get_obj(self, obj_type):
        """Return the first TLV of type *obj_type* (enum or int), or None."""
        obj_type_raw = _enum_or_int(obj_type)
        for obj in self.obj_list:
            if obj.obj_type == obj_type_raw:
                return obj
        return None

    @staticmethod
    def parse_header(data: bytes):
        """Return ``(header, header_size)`` parsed from the start of *data*.

        Raises:
            ValueError: if *data* is shorter than ``IpFw3OpHeader``.
        """
        if len(data) < sizeof(IpFw3OpHeader):
            raise ValueError("length less than op3 message header")
        return IpFw3OpHeader.from_buffer_copy(data), sizeof(IpFw3OpHeader)

    def parse_obj_list(self, data: bytes):
        """Split *data* into opaque TLVs and add them to the message.

        Raises:
            ValueError: if a TLV header is truncated, or a TLV claims a
                length smaller than its header or larger than the remaining
                data.
        """
        hdr_len = sizeof(IpFwObjTlv)
        off = 0
        while off < len(data):
            hdr = IpFwObjTlv.from_buffer_copy(data[off : off + hdr_len])
            if hdr.length < hdr_len:
                # Also guarantees forward progress of the loop.
                raise ValueError("TLV too short")
            if hdr.length + off > len(data):
                raise ValueError("TLV too big")
            # Only the header is interpreted here, so keep the payload opaque.
            self.add_obj(UnknownTlv(hdr.n_type, data[off : off + hdr.length]))
            off += hdr.length

    def is_type(self, msg_type):
        """Return True if *msg_type* (enum or int) matches this message."""
        return _enum_or_int(msg_type) == self.obj_type

    @property
    def obj_name(self):
        """Enum member name, or ``msg#<type>`` when no member was resolved."""
        if self._enum is not None:
            return self._enum.name
        return "msg#{}".format(self.obj_type)

    def print_hdr(self, prepend=""):
        """Print a one-line summary of the message."""
        print("{}len={}, type={}".format(prepend, len(bytes(self)), self.obj_name))

    @classmethod
    def from_bytes(cls, data):
        """Parse a complete message, using ``cls.attr_map`` for its TLVs.

        Raises:
            ValueError: if the header or any TLV is malformed. On a header
                failure the raw bytes are dumped before re-raising.
        """
        try:
            hdr, hdrlen = cls.parse_header(data)
            self = cls(hdr.opcode)
            self._orig_data = data
        except ValueError as e:
            print("Failed to parse op3 header: {}".format(e))
            cls.print_as_bytes(data, "invalid op3 message")
            raise
        tlv_list = Tlv.parse_tlvs(data[hdrlen:], self.attr_map)
        self.obj_list.extend(tlv_list)
        return self

    def __bytes__(self):
        header = bytes(IpFw3OpHeader(opcode=self.obj_type))
        return header + b"".join(bytes(obj) for obj in self.obj_list)

    def print_obj(self):
        """Print the message summary followed by all of its TLVs."""
        self.print_hdr()
        for obj in self.obj_list:
            obj.print_obj(" ")

    @staticmethod
    def print_as_bytes(data: bytes, descr: str = ""):
        """Print *data* as a hex dump, 16 bytes per row, labelled *descr*."""
        print("===vv {} (len:{:3d}) vv===".format(descr, len(data)))
        off = 0
        step = 16
        while off < len(data):
            for i in range(step):
                if off + i < len(data):
                    print(" {:02X}".format(data[off + i]), end="")
            print("")
            off += step
        print("--------------------")


rule_attrs = prepare_attrs_map(
    [
        AttrDescr(
            IpFwTlvType.IPFW_TLV_TBLNAME_LIST,
            CTlv,
            [
                AttrDescr(IpFwTlvType.IPFW_TLV_TBL_NAME, NTlv),
                AttrDescr(IpFwTlvType.IPFW_TLV_STATE_NAME, NTlv),
            ],
            True,
        ),
        AttrDescr(IpFwTlvType.IPFW_TLV_RULE_LIST, CTlvRule),
    ]
)


class IpFwXRule(BaseIpFwMessage):
    """Message wrapper for the ``IP_FW_XADD`` opcode."""

    messages = [Op3CmdType.IP_FW_XADD]
    attr_map = rule_attrs


legacy_classes = []
set3_classes = []
get3_classes = [IpFwXRule]