#!/usr/bin/env python3
"""Helpers to build and parse ipfw op3 ioctl messages and their TLV payloads.

The module mirrors a handful of kernel structures with ``ctypes`` and wraps
them in small Python classes that can be serialized with ``bytes(obj)`` and
reconstructed with ``Class.from_bytes(data)``.
"""
import os
import socket
import struct
import subprocess
import sys
from ctypes import c_byte
from ctypes import c_char
from ctypes import c_int
from ctypes import c_long
from ctypes import c_uint32
from ctypes import c_uint8
from ctypes import c_ulong
from ctypes import c_ushort
from ctypes import sizeof
from ctypes import Structure
from enum import Enum
from typing import Any
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import Union

import pytest
from atf_python.sys.netpfil.ipfw.insns import BaseInsn
from atf_python.sys.netpfil.ipfw.insns import insn_attrs
from atf_python.sys.netpfil.ipfw.ioctl_headers import IpFwTableLookupType
from atf_python.sys.netpfil.ipfw.ioctl_headers import IpFwTableValueType
from atf_python.sys.netpfil.ipfw.ioctl_headers import IpFwTlvType
from atf_python.sys.netpfil.ipfw.ioctl_headers import Op3CmdType
from atf_python.sys.netpfil.ipfw.utils import align8
from atf_python.sys.netpfil.ipfw.utils import AttrDescr
from atf_python.sys.netpfil.ipfw.utils import enum_from_int
from atf_python.sys.netpfil.ipfw.utils import prepare_attrs_map


def _enum_or_int(val):
    """Return ``val.value`` for Enum members and ``val`` itself otherwise."""
    if isinstance(val, Enum):
        return val.value
    return val


def _hexdump(data) -> str:
    """Format bytes as space-separated ``xNN`` tokens."""
    return " ".join(["x{:02X}".format(b) for b in data])


class IpFw3OpHeader(Structure):
    """ctypes layout of the op3 message header (opcode, version, reserved)."""

    _fields_ = [
        ("opcode", c_ushort),
        ("version", c_ushort),
        ("reserved1", c_ushort),
        ("reserved2", c_ushort),
    ]


class IpFwObjTlv(Structure):
    """ctypes layout of the generic TLV header: type, flags, total length."""

    _fields_ = [
        ("n_type", c_ushort),
        ("flags", c_ushort),
        ("length", c_uint32),
    ]


class BaseTlv(object):
    """Common behaviour of all TLV wrappers.

    Stores the numeric type in ``obj_type``, the matching enum member (or
    ``None`` when the type is unknown) in ``_enum`` and child objects in
    ``obj_list``.  Subclasses provide ``__bytes__`` and may override
    ``_validate``/``_parse``.
    """

    obj_enum_class = IpFwTlvType

    def __init__(self, obj_type):
        """Accept either an Enum member or a raw integer TLV type."""
        if isinstance(obj_type, Enum):
            self.obj_type = obj_type.value
            self._enum = obj_type
        else:
            self.obj_type = obj_type
            self._enum = enum_from_int(self.obj_enum_class, obj_type)
        self.obj_list = []

    def add_obj(self, obj):
        """Append a child object."""
        self.obj_list.append(obj)

    @property
    def len(self):
        """Serialized size in bytes."""
        return len(bytes(self))

    @property
    def obj_name(self):
        """Enum name of the type, or ``tlv#<type>`` when it is unknown."""
        if self._enum is not None:
            return self._enum.name
        else:
            return "tlv#{}".format(self.obj_type)

    def print_hdr(self, prepend=""):
        """Print a one-line summary of this TLV."""
        print(
            "{}len={} type={}({}){}".format(
                prepend, self.len, self.obj_name, self.obj_type, self._print_obj_value()
            )
        )

    def print_obj(self, prepend=""):
        """Print this TLV and, indented by one space more, all children."""
        self.print_hdr(prepend)
        prepend = " " + prepend
        for obj in self.obj_list:
            obj.print_obj(prepend)

    def print_obj_hex(self, prepend=""):
        """Print ``prepend``, an empty line and the hex dump of the TLV."""
        print(prepend)
        print()
        print(_hexdump(bytes(self)))

    @classmethod
    def _validate(cls, data):
        """Raise ValueError unless ``data`` is exactly one well-sized TLV."""
        if len(data) < sizeof(IpFwObjTlv):
            raise ValueError("TLV too short")
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        if len(data) != hdr.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        """Build an instance from already validated ``data``."""
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        return cls(hdr.n_type)

    @classmethod
    def from_bytes(cls, data, attr_map=None):
        """Validate ``data`` and parse it into an instance of ``cls``."""
        cls._validate(data)
        obj = cls._parse(data, attr_map)
        return obj

    def __bytes__(self):
        raise NotImplementedError()

    def _print_obj_value(self):
        """Return the payload (everything after the TLV header) as hex.

        Uses ``bytes(self)`` rather than a ``_data`` attribute so that it
        works for every subclass that implements ``__bytes__``.
        """
        return " " + _hexdump(bytes(self)[sizeof(IpFwObjTlv) :])

    def as_hexdump(self):
        """Return the serialized TLV as a hex string."""
        return _hexdump(bytes(self))


class UnknownTlv(BaseTlv):
    """Opaque TLV that keeps the raw bytes it was parsed from."""

    def __init__(self, obj_type, data):
        super().__init__(obj_type)
        self._data = data

    @classmethod
    def _validate(cls, data):
        """Raise ValueError unless ``data`` is exactly one well-sized TLV."""
        # An unknown TLV only has to contain the generic header; requiring
        # the (much larger) named-TLV structure rejected valid short TLVs.
        if len(data) < sizeof(IpFwObjTlv):
            raise ValueError("TLV size is too short")
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        if len(data) != hdr.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        self = cls(hdr.n_type, data)
        return self

    def __bytes__(self):
        return self._data


class Tlv(BaseTlv):
    """TLV base class that knows how to split a buffer into child TLVs."""

    @staticmethod
    def parse_tlvs(data, attr_map):
        """Parse consecutive TLVs from ``data``.

        ``attr_map`` maps a numeric TLV type to a descriptor dict with an
        ``"ad"`` entry (whose ``cls`` attribute is the wrapper class) and an
        optional ``"child"`` map; types missing from the map are returned as
        ``UnknownTlv``.  ``None`` is treated as an empty map.

        Raises ValueError if a TLV declares a length smaller than its own
        header or larger than the remaining data.
        """
        if attr_map is None:
            attr_map = {}
        hdr_size = sizeof(IpFwObjTlv)
        off = 0
        ret = []
        while off + hdr_size <= len(data):
            hdr = IpFwObjTlv.from_buffer_copy(data[off : off + hdr_size])
            # A length below the header size (notably 0) would never advance
            # the offset and loop forever.
            if hdr.length < hdr_size:
                raise ValueError("TLV size is too short")
            if off + hdr.length > len(data):
                raise ValueError("TLV size do not match")
            obj_data = data[off : off + hdr.length]
            obj_descr = attr_map.get(hdr.n_type, None)
            if obj_descr is None:
                cls = UnknownTlv
                child_map = {}
            else:
                cls = obj_descr["ad"].cls
                child_map = obj_descr.get("child", {})
            obj = cls.from_bytes(obj_data, child_map)
            ret.append(obj)
            off += hdr.length
        return ret


class IpFwObjNTlv(Structure):
    """ctypes layout of a named-object TLV (header, index, set, type, name)."""

    _fields_ = [
        ("head", IpFwObjTlv),
        ("idx", c_ushort),
        ("n_set", c_uint8),
        ("n_type", c_uint8),
        ("spare", c_uint32),
        ("name", c_char * 64),
    ]


class NTlv(Tlv):
    """Wrapper around ``IpFwObjNTlv``."""

    def __init__(self, obj_type, idx=0, n_set=0, n_type=0, name=None):
        super().__init__(obj_type)
        self.n_idx = idx
        self.n_set = n_set
        self.n_type = n_type
        self.n_name = name

    @classmethod
    def _validate(cls, data):
        """Raise ValueError unless ``data`` is exactly one ``IpFwObjNTlv``."""
        if len(data) != sizeof(IpFwObjNTlv):
            raise ValueError("TLV size is not correct")
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        if len(data) != hdr.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        hdr = IpFwObjNTlv.from_buffer_copy(data[: sizeof(IpFwObjNTlv)])
        name = hdr.name.decode("utf-8")
        self = cls(hdr.head.n_type, hdr.idx, hdr.n_set, hdr.n_type, name)
        return self

    def __bytes__(self):
        # ``name`` defaults to None; serialize that as an empty name instead
        # of failing on ``None.encode``.
        name_bytes = (self.n_name or "").encode("utf-8")
        # Pad with NULs / truncate to the fixed 64-byte name field.
        name_bytes = name_bytes.ljust(64, b"\0")[:64]
        hdr = IpFwObjNTlv(
            head=IpFwObjTlv(n_type=self.obj_type, length=sizeof(IpFwObjNTlv)),
            idx=self.n_idx,
            n_set=self.n_set,
            n_type=self.n_type,
            name=name_bytes,
        )
        return bytes(hdr)

    def _print_obj_value(self):
        return " " + "type={} set={} idx={} name={}".format(
            self.n_type, self.n_set, self.n_idx, self.n_name
        )


class IpFwObjCTlv(Structure):
    """ctypes layout of a container TLV header."""

    _fields_ = [
        ("head", IpFwObjTlv),
        ("count", c_uint32),
        ("objsize", c_ushort),
        ("version", c_uint8),
        ("flags", c_uint8),
    ]


class CTlv(Tlv):
    """Container TLV: an ``IpFwObjCTlv`` header followed by child TLVs."""

    def __init__(self, obj_type, obj_list=None):
        super().__init__(obj_type)
        if obj_list:
            self.obj_list.extend(obj_list)

    @classmethod
    def _validate(cls, data):
        """Raise ValueError unless ``data`` is exactly one container TLV."""
        if len(data) < sizeof(IpFwObjCTlv):
            raise ValueError("TLV too short")
        hdr = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)])
        if len(data) != hdr.head.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        hdr = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)])
        tlv_list = cls.parse_tlvs(data[sizeof(IpFwObjCTlv) :], attr_map)
        if len(tlv_list) != hdr.count:
            raise ValueError("wrong number of objects")
        self = cls(hdr.head.n_type, obj_list=tlv_list)
        return self

    def __bytes__(self):
        ret = b"".join(bytes(obj) for obj in self.obj_list)
        # Total length covers the container header plus all children; this
        # is what _validate() checks on the way back in.
        length = len(ret) + sizeof(IpFwObjCTlv)
        if self.obj_list:
            objsize = len(bytes(self.obj_list[0]))
        else:
            objsize = 0
        hdr = IpFwObjCTlv(
            head=IpFwObjTlv(n_type=self.obj_type, length=length),
            count=len(self.obj_list),
            objsize=objsize,
        )
        return bytes(hdr) + ret

    def _print_obj_value(self):
        return ""


class IpFwRule(Structure):
    """ctypes layout of the fixed rule header preceding the instructions."""

    _fields_ = [
        ("act_ofs", c_ushort),
        ("cmd_len", c_ushort),
        ("spare", c_ushort),
        ("n_set", c_uint8),
        ("flags", c_uint8),
        ("rulenum", c_uint32),
        ("n_id", c_uint32),
    ]


class RawRule(Tlv):
    """A single rule: ``IpFwRule`` header followed by instruction objects."""

    def __init__(self, obj_type=0, n_set=0, rulenum=0, obj_list=None):
        super().__init__(obj_type)
        self.n_set = n_set
        self.rulenum = rulenum
        if obj_list:
            self.obj_list.extend(obj_list)

    @classmethod
    def _validate(cls, data):
        """Raise ValueError unless ``data`` holds the header plus exactly
        ``cmd_len`` 4-byte words of instructions."""
        min_size = sizeof(IpFwRule)
        if len(data) < min_size:
            raise ValueError("rule TLV too short")
        rule = IpFwRule.from_buffer_copy(data[:min_size])
        if len(data) != min_size + rule.cmd_len * 4:
            raise ValueError("rule TLV cmd_len incorrect")

    @classmethod
    def _parse(cls, data, attr_map):
        hdr = IpFwRule.from_buffer_copy(data[: sizeof(IpFwRule)])
        self = cls(
            n_set=hdr.n_set,
            rulenum=hdr.rulenum,
            obj_list=BaseInsn.parse_insns(data[sizeof(IpFwRule) :], insn_attrs),
        )
        return self

    def __bytes__(self):
        act_ofs = 0
        cmd_len = 0  # running length in 4-byte words
        chunks = []
        for obj in self.obj_list:
            # The offset of the first action instruction is recorded once.
            if obj.is_action and act_ofs == 0:
                act_ofs = cmd_len
            obj_bytes = bytes(obj)
            cmd_len += len(obj_bytes) // 4
            chunks.append(obj_bytes)

        hdr = IpFwRule(
            act_ofs=act_ofs,
            cmd_len=cmd_len,
            n_set=self.n_set,
            rulenum=self.rulenum,
        )
        return bytes(hdr) + b"".join(chunks)

    @property
    def obj_name(self):
        return "rule#{}".format(self.rulenum)

    def _print_obj_value(self):
        cmd_len = sum([len(bytes(obj)) for obj in self.obj_list]) // 4
        return " set={} cmd_len={}".format(self.n_set, cmd_len)


class CTlvRule(CTlv):
    """Container TLV whose payload is a list of 8-byte aligned rules."""

    def __init__(self, obj_type=IpFwTlvType.IPFW_TLV_RULE_LIST, obj_list=None):
        super().__init__(obj_type, obj_list)

    @classmethod
    def _parse(cls, data, attr_map):
        chdr = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)])
        rule_list = []
        off = sizeof(IpFwObjCTlv)
        while off + sizeof(IpFwRule) <= len(data):
            hdr = IpFwRule.from_buffer_copy(data[off : off + sizeof(IpFwRule)])
            rule_len = sizeof(IpFwRule) + hdr.cmd_len * 4
            if off + rule_len > len(data):
                raise ValueError("wrong rule size")
            rule = RawRule.from_bytes(data[off : off + rule_len])
            rule_list.append(rule)
            # Rules are laid out on 8-byte boundaries.
            off += align8(rule_len)
        if off != len(data):
            raise ValueError("rule bytes left: off={} len={}".format(off, len(data)))
        return cls(chdr.head.n_type, obj_list=rule_list)

    # XXX: _validate

    def __bytes__(self):
        chunks = []
        for rule in self.obj_list:
            rule_bytes = bytes(rule)
            remainder = len(rule_bytes) % 8
            if remainder > 0:
                # Pad every rule up to the next 8-byte boundary.
                rule_bytes += b"\0" * (8 - remainder)
            chunks.append(rule_bytes)
        ret = b"".join(chunks)
        hdr = IpFwObjCTlv(
            head=IpFwObjTlv(
                n_type=self.obj_type, length=len(ret) + sizeof(IpFwObjCTlv)
            ),
            count=len(self.obj_list),
        )
        return bytes(hdr) + ret


class BaseIpFwMessage(object):
    """An op3 message: ``IpFw3OpHeader`` followed by a list of TLVs.

    Subclasses set ``messages`` (the opcodes they handle) and ``attr_map``
    (the TLV type map handed to ``Tlv.parse_tlvs``).
    """

    messages = []
    # With an empty map every TLV is parsed as UnknownTlv.
    attr_map = {}

    def __init__(self, msg_type, obj_list=None):
        """Accept either an Enum member or a raw integer opcode."""
        if isinstance(msg_type, Enum):
            self.obj_type = msg_type.value
            self._enum = msg_type
        else:
            self.obj_type = msg_type
            self._enum = enum_from_int(self.messages, self.obj_type)
        self.obj_list = []
        if obj_list:
            self.obj_list.extend(obj_list)

    def add_obj(self, obj):
        """Append a TLV to the message."""
        self.obj_list.append(obj)

    def get_obj(self, obj_type):
        """Return the first TLV of the given type (Enum or int), or None."""
        obj_type_raw = _enum_or_int(obj_type)
        for obj in self.obj_list:
            if obj.obj_type == obj_type_raw:
                return obj
        return None

    @staticmethod
    def parse_header(data: bytes):
        """Return ``(header, header_size)``; raise ValueError if too short."""
        if len(data) < sizeof(IpFw3OpHeader):
            raise ValueError("length less than op3 message header")
        return IpFw3OpHeader.from_buffer_copy(data), sizeof(IpFw3OpHeader)

    def parse_obj_list(self, data: bytes):
        """Split ``data`` into raw TLVs and add each one as ``UnknownTlv``.

        Raises ValueError on a truncated header or a TLV whose declared
        length is smaller than its header or exceeds the remaining data.
        """
        hdr_size = sizeof(IpFwObjTlv)
        off = 0
        while off < len(data):
            if off + hdr_size > len(data):
                raise ValueError("TLV too short")
            hdr = IpFwObjTlv.from_buffer_copy(data[off : off + hdr_size])
            # Guards against a zero length that would never advance ``off``.
            if hdr.length < hdr_size:
                raise ValueError("TLV too short")
            if hdr.length + off > len(data):
                raise ValueError("TLV too big")
            tlv = UnknownTlv(hdr.n_type, data[off : off + hdr.length])
            self.add_obj(tlv)
            off += hdr.length

    def is_type(self, msg_type):
        """Return True if the message opcode equals ``msg_type``."""
        return _enum_or_int(msg_type) == self.obj_type

    @property
    def obj_name(self):
        """Enum name of the opcode, or ``msg#<opcode>`` when it is unknown."""
        if self._enum is not None:
            return self._enum.name
        else:
            return "msg#{}".format(self.obj_type)

    def print_hdr(self, prepend=""):
        """Print a one-line summary of the message."""
        print("{}len={}, type={}".format(prepend, len(bytes(self)), self.obj_name))

    @classmethod
    def from_bytes(cls, data):
        """Parse a complete message; the original bytes are kept in
        ``_orig_data``.  Header errors are printed and re-raised."""
        try:
            hdr, hdrlen = cls.parse_header(data)
            self = cls(hdr.opcode)
            self._orig_data = data
        except ValueError as e:
            print("Failed to parse op3 header: {}".format(e))
            cls.print_as_bytes(data, "op3 message")
            raise
        tlv_list = Tlv.parse_tlvs(data[hdrlen:], self.attr_map)
        self.obj_list.extend(tlv_list)
        return self

    def __bytes__(self):
        ret = bytes(IpFw3OpHeader(opcode=self.obj_type))
        for obj in self.obj_list:
            ret += bytes(obj)
        return ret

    def print_obj(self):
        """Print the message summary followed by all TLVs."""
        self.print_hdr()
        for obj in self.obj_list:
            obj.print_obj(" ")

    @staticmethod
    def print_as_bytes(data: bytes, descr: str = ""):
        """Print ``data`` as rows of up to 16 hex bytes under a banner."""
        print("===vv {} (len:{:3d}) vv===".format(descr, len(data)))
        step = 16
        for off in range(0, len(data), step):
            print("".join(" {:02X}".format(b) for b in data[off : off + step]))
        print("--------------------")


rule_attrs = prepare_attrs_map(
    [
        AttrDescr(
            IpFwTlvType.IPFW_TLV_TBLNAME_LIST,
            CTlv,
            [
                AttrDescr(IpFwTlvType.IPFW_TLV_TBL_NAME, NTlv),
                AttrDescr(IpFwTlvType.IPFW_TLV_STATE_NAME, NTlv),
                AttrDescr(IpFwTlvType.IPFW_TLV_EACTION, NTlv),
            ],
            True,
        ),
        AttrDescr(IpFwTlvType.IPFW_TLV_RULE_LIST, CTlvRule),
    ]
)


class IpFwXRule(BaseIpFwMessage):
    """Message class registered for the ``IP_FW_XADD`` opcode."""

    messages = [Op3CmdType.IP_FW_XADD]
    attr_map = rule_attrs


legacy_classes = []
set3_classes = []
get3_classes = [IpFwXRule]