#!/usr/bin/env python3
import os
import socket
import struct
import subprocess
import sys
from ctypes import c_byte
from ctypes import c_char
from ctypes import c_int
from ctypes import c_long
from ctypes import c_uint32
from ctypes import c_uint8
from ctypes import c_ulong
from ctypes import c_ushort
from ctypes import sizeof
from ctypes import Structure
from enum import Enum
from typing import Any
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import Union

import pytest
from atf_python.sys.netpfil.ipfw.insns import BaseInsn
from atf_python.sys.netpfil.ipfw.insns import insn_attrs
from atf_python.sys.netpfil.ipfw.ioctl_headers import IpFwTableLookupType
from atf_python.sys.netpfil.ipfw.ioctl_headers import IpFwTlvType
from atf_python.sys.netpfil.ipfw.ioctl_headers import Op3CmdType
from atf_python.sys.netpfil.ipfw.utils import align8
from atf_python.sys.netpfil.ipfw.utils import AttrDescr
from atf_python.sys.netpfil.ipfw.utils import enum_from_int
from atf_python.sys.netpfil.ipfw.utils import prepare_attrs_map


def enum_or_int(val):
    """Return the raw value of *val*, unwrapping an ``Enum`` member.

    Used by ``BaseIpFwMessage.get_obj`` and ``BaseIpFwMessage.is_type`` so
    callers may pass either an enum member or a plain integer.
    """
    if isinstance(val, Enum):
        return val.value
    return val


class IpFw3OpHeader(Structure):
    # Header that BaseIpFwMessage serializes in front of its TLV objects.
    _fields_ = [
        ("opcode", c_ushort),
        ("version", c_ushort),
        ("reserved1", c_ushort),
        ("reserved2", c_ushort),
    ]


class IpFwObjTlv(Structure):
    # Generic TLV header. `length` is the total TLV size in bytes, header
    # included (BaseTlv._validate compares it against the whole buffer).
    _fields_ = [
        ("n_type", c_ushort),
        ("flags", c_ushort),
        ("length", c_uint32),
    ]


class BaseTlv(object):
    """Base class for TLV objects.

    Subclasses implement ``__bytes__`` for serialization and may override
    ``_validate`` / ``_parse`` to decode their own payload in ``from_bytes``.
    """

    obj_enum_class = IpFwTlvType

    def __init__(self, obj_type):
        # Accept either an enum member or a raw integer type code.
        if isinstance(obj_type, Enum):
            self.obj_type = obj_type.value
            self._enum = obj_type
        else:
            self.obj_type = obj_type
            self._enum = enum_from_int(self.obj_enum_class, obj_type)
        self.obj_list = []

    def add_obj(self, obj):
        """Append a child object."""
        self.obj_list.append(obj)

    @property
    def len(self):
        """Length of the serialized object in bytes."""
        return len(bytes(self))

    @property
    def obj_name(self):
        """Enum name of the TLV type, or ``tlv#<n>`` if it is unknown."""
        if self._enum is not None:
            return self._enum.name
        return "tlv#{}".format(self.obj_type)

    def print_hdr(self, prepend=""):
        print(
            "{}len={} type={}({}){}".format(
                prepend, self.len, self.obj_name, self.obj_type, self._print_obj_value()
            )
        )

    def print_obj(self, prepend=""):
        """Print this object and, recursively, its children (indented)."""
        self.print_hdr(prepend)
        prepend = " " + prepend
        for obj in self.obj_list:
            obj.print_obj(prepend)

    def print_obj_hex(self, prepend=""):
        print(prepend)
        print()
        print(self.as_hexdump())

    @classmethod
    def _validate(cls, data):
        """Raise ValueError unless *data* holds exactly one TLV."""
        if len(data) < sizeof(IpFwObjTlv):
            raise ValueError("TLV too short")
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        if len(data) != hdr.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        return cls(hdr.n_type)

    @classmethod
    def from_bytes(cls, data, attr_map=None):
        """Validate and decode a single TLV from *data*."""
        cls._validate(data)
        return cls._parse(data, attr_map)

    def __bytes__(self):
        raise NotImplementedError()

    def _print_obj_value(self):
        # Default rendering: hex dump of the payload following the header.
        # Uses bytes(self) so any subclass implementing __bytes__ works.
        payload = bytes(self)[sizeof(IpFwObjTlv) :]
        return " " + " ".join(["x{:02X}".format(b) for b in payload])

    def as_hexdump(self):
        """Return the serialized object as space-separated ``xNN`` bytes."""
        return " ".join(["x{:02X}".format(b) for b in bytes(self)])


class UnknownTlv(BaseTlv):
    """TLV of an unrecognized type, kept as its raw bytes."""

    def __init__(self, obj_type, data):
        super().__init__(obj_type)
        self._data = data

    @classmethod
    def _validate(cls, data):
        # An unknown TLV only needs the generic header, not a full named
        # (IpFwObjNTlv) header; checking against the latter rejected short
        # but valid TLVs.
        if len(data) < sizeof(IpFwObjTlv):
            raise ValueError("TLV size is too short")
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        if len(data) != hdr.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        return cls(hdr.n_type, data)

    def __bytes__(self):
        # __bytes__ must return a real bytes object (data may be a slice of
        # a bytearray/memoryview).
        return bytes(self._data)


class Tlv(BaseTlv):
    @staticmethod
    def parse_tlvs(data, attr_map):
        """Split *data* into consecutive TLVs and decode each one.

        Each TLV type is looked up in *attr_map*; entries provide the class
        under ``"ad"`` and an optional child map under ``"child"``. Unknown
        types are decoded as ``UnknownTlv``. Trailing bytes shorter than a
        TLV header are ignored.

        Raises:
            ValueError: on a TLV whose declared length is smaller than its
                header or runs past the end of *data*.
        """
        hdr_len = sizeof(IpFwObjTlv)
        off = 0
        ret = []
        while off + hdr_len <= len(data):
            hdr = IpFwObjTlv.from_buffer_copy(data[off : off + hdr_len])
            # A length below the header size would never advance `off`.
            if hdr.length < hdr_len:
                raise ValueError("TLV size too small")
            if off + hdr.length > len(data):
                raise ValueError("TLV size do not match")
            obj_data = data[off : off + hdr.length]
            obj_descr = attr_map.get(hdr.n_type, None)
            if obj_descr is None:
                cls = UnknownTlv
                child_map = {}
            else:
                cls = obj_descr["ad"].cls
                child_map = obj_descr.get("child", {})
            ret.append(cls.from_bytes(obj_data, child_map))
            off += hdr.length
        return ret


class IpFwObjNTlv(Structure):
    # TLV header followed by index/set/type and a fixed 64-byte name.
    _fields_ = [
        ("head", IpFwObjTlv),
        ("idx", c_ushort),
        ("n_set", c_uint8),
        ("n_type", c_uint8),
        ("spare", c_uint32),
        ("name", c_char * 64),
    ]


class NTlv(Tlv):
    """Fixed-size named TLV (see IpFwObjNTlv)."""

    def __init__(self, obj_type, idx=0, n_set=0, n_type=0, name=None):
        super().__init__(obj_type)
        self.n_idx = idx
        self.n_set = n_set
        self.n_type = n_type
        self.n_name = name

    @classmethod
    def _validate(cls, data):
        if len(data) != sizeof(IpFwObjNTlv):
            raise ValueError("TLV size is not correct")
        hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)])
        if len(data) != hdr.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        hdr = IpFwObjNTlv.from_buffer_copy(data[: sizeof(IpFwObjNTlv)])
        name = hdr.name.decode("utf-8")
        return cls(hdr.head.n_type, hdr.idx, hdr.n_set, hdr.n_type, name)

    def __bytes__(self):
        # The default name=None serializes as an empty (all-zero) name;
        # longer names are truncated to the 64-byte field.
        name_bytes = (self.n_name or "").encode("utf-8")[:64].ljust(64, b"\0")
        hdr = IpFwObjNTlv(
            head=IpFwObjTlv(n_type=self.obj_type, length=sizeof(IpFwObjNTlv)),
            idx=self.n_idx,
            n_set=self.n_set,
            n_type=self.n_type,
            name=name_bytes,
        )
        return bytes(hdr)

    def _print_obj_value(self):
        return " " + "type={} set={} idx={} name={}".format(
            self.n_type, self.n_set, self.n_idx, self.n_name
        )


class IpFwObjCTlv(Structure):
    # Container TLV header: `count` child objects of `objsize` bytes each
    # follow it.
    _fields_ = [
        ("head", IpFwObjTlv),
        ("count", c_uint32),
        ("objsize", c_ushort),
        ("version", c_uint8),
        ("flags", c_uint8),
    ]


class CTlv(Tlv):
    """Container TLV holding a list of child TLVs."""

    def __init__(self, obj_type, obj_list=None):
        super().__init__(obj_type)
        if obj_list:
            self.obj_list.extend(obj_list)

    @classmethod
    def _validate(cls, data):
        if len(data) < sizeof(IpFwObjCTlv):
            raise ValueError("TLV too short")
        hdr = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)])
        if len(data) != hdr.head.length:
            raise ValueError("wrong TLV size")

    @classmethod
    def _parse(cls, data, attr_map):
        hdr = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)])
        tlv_list = cls.parse_tlvs(data[sizeof(IpFwObjCTlv) :], attr_map)
        if len(tlv_list) != hdr.count:
            raise ValueError("wrong number of objects")
        return cls(hdr.head.n_type, obj_list=tlv_list)

    def __bytes__(self):
        parts = [bytes(obj) for obj in self.obj_list]
        body = b"".join(parts)
        # objsize advertises a single child's size, taken from the first one.
        objsize = len(parts[0]) if parts else 0
        hdr = IpFwObjCTlv(
            # The header length covers the header plus every child; it used
            # to be hard-coded to sizeof(IpFwObjNTlv).
            head=IpFwObjTlv(
                n_type=self.obj_type, length=sizeof(IpFwObjCTlv) + len(body)
            ),
            count=len(self.obj_list),
            objsize=objsize,
        )
        return bytes(hdr) + body

    def _print_obj_value(self):
        return ""


class IpFwRule(Structure):
    # Rule header. `cmd_len` and `act_ofs` are counted in 4-byte
    # instruction words (see RawRule._validate / RawRule.__bytes__).
    _fields_ = [
        ("act_ofs", c_ushort),
        ("cmd_len", c_ushort),
        ("spare", c_ushort),
        ("n_set", c_uint8),
        ("flags", c_uint8),
        ("rulenum", c_uint32),
        ("n_id", c_uint32),
    ]


class RawRule(Tlv):
    """A single rule: IpFwRule header followed by instruction objects."""

    def __init__(self, obj_type=0, n_set=0, rulenum=0, obj_list=None):
        super().__init__(obj_type)
        self.n_set = n_set
        self.rulenum = rulenum
        if obj_list:
            self.obj_list.extend(obj_list)

    @classmethod
    def _validate(cls, data):
        min_size = sizeof(IpFwRule)
        if len(data) < min_size:
            raise ValueError("rule TLV too short")
        rule = IpFwRule.from_buffer_copy(data[:min_size])
        if len(data) != min_size + rule.cmd_len * 4:
            raise ValueError("rule TLV cmd_len incorrect")

    @classmethod
    def _parse(cls, data, attr_map):
        hdr = IpFwRule.from_buffer_copy(data[: sizeof(IpFwRule)])
        return cls(
            n_set=hdr.n_set,
            rulenum=hdr.rulenum,
            obj_list=BaseInsn.parse_insns(data[sizeof(IpFwRule) :], insn_attrs),
        )

    def __bytes__(self):
        # act_ofs is the word offset of the FIRST action instruction. Use
        # None as the "not found yet" marker so that an action at offset 0
        # is not overridden by a later one.
        act_ofs = None
        cmd_len = 0
        parts = []
        for obj in self.obj_list:
            if obj.is_action and act_ofs is None:
                act_ofs = cmd_len
            obj_bytes = bytes(obj)
            cmd_len += len(obj_bytes) // 4
            parts.append(obj_bytes)

        hdr = IpFwRule(
            act_ofs=act_ofs if act_ofs is not None else 0,
            cmd_len=cmd_len,
            n_set=self.n_set,
            rulenum=self.rulenum,
        )
        return bytes(hdr) + b"".join(parts)

    @property
    def obj_name(self):
        return "rule#{}".format(self.rulenum)

    def _print_obj_value(self):
        cmd_len = sum(len(bytes(obj)) for obj in self.obj_list) // 4
        return " set={} cmd_len={}".format(self.n_set, cmd_len)


class CTlvRule(CTlv):
    """Container of RawRule objects, each padded to an 8-byte boundary."""

    def __init__(self, obj_type=IpFwTlvType.IPFW_TLV_RULE_LIST, obj_list=None):
        super().__init__(obj_type, obj_list)

    @classmethod
    def _parse(cls, data, attr_map):
        chdr = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)])
        rule_list = []
        off = sizeof(IpFwObjCTlv)
        while off + sizeof(IpFwRule) <= len(data):
            hdr = IpFwRule.from_buffer_copy(data[off : off + sizeof(IpFwRule)])
            rule_len = sizeof(IpFwRule) + hdr.cmd_len * 4
            if off + rule_len > len(data):
                raise ValueError("wrong rule size")
            rule_list.append(RawRule.from_bytes(data[off : off + rule_len]))
            # Rules are laid out on 8-byte boundaries (see __bytes__).
            off += align8(rule_len)
        if off != len(data):
            raise ValueError("rule bytes left: off={} len={}".format(off, len(data)))
        return cls(chdr.head.n_type, obj_list=rule_list)

    # Header/length validation is inherited from CTlv._validate.

    def __bytes__(self):
        ret = b""
        for rule in self.obj_list:
            rule_bytes = bytes(rule)
            remainder = len(rule_bytes) % 8
            if remainder > 0:
                rule_bytes += b"\0" * (8 - remainder)
            ret += rule_bytes
        hdr = IpFwObjCTlv(
            head=IpFwObjTlv(
                n_type=self.obj_type, length=len(ret) + sizeof(IpFwObjCTlv)
            ),
            count=len(self.obj_list),
        )
        return bytes(hdr) + ret


class BaseIpFwMessage(object):
    """An op3 message: IpFw3OpHeader followed by a list of TLV objects.

    Subclasses set ``messages`` (the opcodes they represent) and
    ``attr_map`` (how TLV types are decoded in ``from_bytes``).
    """

    messages = []
    # Default decoding map: every TLV is parsed as UnknownTlv.
    attr_map = {}

    def __init__(self, msg_type, obj_list=None):
        if isinstance(msg_type, Enum):
            self.obj_type = msg_type.value
            self._enum = msg_type
        else:
            self.obj_type = msg_type
            self._enum = enum_from_int(self.messages, self.obj_type)
        self.obj_list = []
        if obj_list:
            self.obj_list.extend(obj_list)

    def add_obj(self, obj):
        """Append a top-level TLV object."""
        self.obj_list.append(obj)

    def get_obj(self, obj_type):
        """Return the first object of *obj_type* (enum or int), else None."""
        obj_type_raw = enum_or_int(obj_type)
        for obj in self.obj_list:
            if obj.obj_type == obj_type_raw:
                return obj
        return None

    @staticmethod
    def parse_header(data: bytes):
        """Return ``(IpFw3OpHeader, header_size)`` decoded from *data*."""
        if len(data) < sizeof(IpFw3OpHeader):
            raise ValueError("length less than op3 message header")
        return IpFw3OpHeader.from_buffer_copy(data), sizeof(IpFw3OpHeader)

    def parse_obj_list(self, data: bytes):
        """Split *data* into raw TLVs and append each as an UnknownTlv.

        Raises:
            ValueError: on a truncated header, a TLV shorter than its header
                or a TLV running past the end of *data*.
        """
        hdr_len = sizeof(IpFwObjTlv)
        off = 0
        while off < len(data):
            if off + hdr_len > len(data):
                raise ValueError("TLV header truncated")
            hdr = IpFwObjTlv.from_buffer_copy(data[off : off + hdr_len])
            # Guards against an endless loop on a zero/short length.
            if hdr.length < hdr_len:
                raise ValueError("TLV too small")
            if hdr.length + off > len(data):
                raise ValueError("TLV too big")
            self.add_obj(UnknownTlv(hdr.n_type, data[off : off + hdr.length]))
            off += hdr.length

    def is_type(self, msg_type):
        """Return True if this message's opcode equals *msg_type*."""
        return enum_or_int(msg_type) == self.obj_type

    @property
    def obj_name(self):
        if self._enum is not None:
            return self._enum.name
        return "msg#{}".format(self.obj_type)

    def print_hdr(self, prepend=""):
        print("{}len={}, type={}".format(prepend, len(bytes(self)), self.obj_name))

    @classmethod
    def from_bytes(cls, data):
        """Decode a full op3 message (header plus TLVs) from *data*."""
        try:
            hdr, hdrlen = cls.parse_header(data)
            self = cls(hdr.opcode)
            self._orig_data = data
        except ValueError as e:
            print("Failed to parse op3 header: {}".format(e))
            cls.print_as_bytes(data, "op3 message")
            raise
        tlv_list = Tlv.parse_tlvs(data[hdrlen:], self.attr_map)
        self.obj_list.extend(tlv_list)
        return self

    def __bytes__(self):
        ret = bytes(IpFw3OpHeader(opcode=self.obj_type))
        for obj in self.obj_list:
            ret += bytes(obj)
        return ret

    def print_obj(self):
        self.print_hdr()
        for obj in self.obj_list:
            obj.print_obj(" ")

    @staticmethod
    def print_as_bytes(data: bytes, descr: str = ""):
        """Print *data* as hex, 16 bytes per line, framed by *descr*."""
        print("===vv {} (len:{:3d}) vv===".format(descr, len(data)))
        off = 0
        step = 16
        while off < len(data):
            for i in range(step):
                if off + i < len(data):
                    print(" {:02X}".format(data[off + i]), end="")
            print("")
            off += step
        print("--------------------")


rule_attrs = prepare_attrs_map(
    [
        AttrDescr(
            IpFwTlvType.IPFW_TLV_TBLNAME_LIST,
            CTlv,
            [
                AttrDescr(IpFwTlvType.IPFW_TLV_TBL_NAME, NTlv),
                AttrDescr(IpFwTlvType.IPFW_TLV_STATE_NAME, NTlv),
                AttrDescr(IpFwTlvType.IPFW_TLV_EACTION, NTlv),
            ],
            True,
        ),
        AttrDescr(IpFwTlvType.IPFW_TLV_RULE_LIST, CTlvRule),
    ]
)


class IpFwXRule(BaseIpFwMessage):
    messages = [Op3CmdType.IP_FW_XADD]
    attr_map = rule_attrs


legacy_classes = []
# Message-class registry for the "set3" group; nothing is registered yet.
set3_classes = []
# Message-class registry for the "get3" group.
get3_classes = [
    IpFwXRule,
]