#!@PYTHON@

#
# This file and its contents are supplied under the terms of the
# Common Development and Distribution License ("CDDL"), version 1.0.
# You may only use this file in accordance with the terms of version
# 1.0 of the CDDL.
#
# A full copy of the text of the CDDL should have accompanied this
# source. A copy of the CDDL is also available via the Internet at
# http://www.illumos.org/license/CDDL.
#

#
# Copyright 2024 Bill Sommerfeld <sommerfeld@hamachi.org>
#

""" Set up multiple bound sockets and then send/connect to some of them.
Helps to test that link-local addresses are properly scoped """

import argparse
import fcntl
import struct
import socket
import sys
from threading import Thread
from typing import List, Optional, Tuple, Type, Union

# A socket address as returned in slot 4 of a socket.getaddrinfo() entry:
# (addr, port) for IPv4, (addr, port, flowinfo, scope_id) for IPv6.
SockAddr = Union[Tuple[str, int], Tuple[str, int, int, int]]

# Socket options that bind a socket to an interface index
# (applied by TestProto.bind_ifindex).
# NOTE(review): values presumably mirror illumos <netinet/in.h>; confirm.
IP_BOUND_IF = 0x41
IPV6_BOUND_IF = 0x41

# ioctl request number and struct lifreq layout used by get_ifindex():
# total buffer size, size of the leading interface-name field, and the
# byte offset at which the returned ifindex is read.
# NOTE(review): presumably these match illumos <sys/sockio.h> and
# <net/if.h> for the target ABI; confirm before porting.
SIOCGLIFINDEX = 0xc0786985
LIFREQSIZE = 376
LIFNAMSIZ = 32
LIFRU_OFFSET = 40


def get_ifindex(arg: str) -> int:
    """Look up the ifindex corresponding to a named interface.

    Raises ValueError if the name is too long for the lifreq name field
    or is not ASCII; OSError if the SIOCGLIFINDEX ioctl fails."""
    buf = bytearray(LIFREQSIZE)

    ifname = bytes(arg, encoding='ascii')
    # '>=' leaves room for the NUL terminator (buf is zero-filled).
    if len(ifname) >= LIFNAMSIZ:
        raise ValueError('Interface name too long', arg)
    buf[0:len(ifname)] = ifname

    with socket.socket(family=socket.AF_INET6,
                       type=socket.SOCK_DGRAM,
                       proto=socket.IPPROTO_UDP) as s:
        # The ioctl fills in buf in place (buf is a mutable bytearray).
        fcntl.ioctl(s.fileno(), SIOCGLIFINDEX, buf)
    return struct.unpack_from('i', buffer=buf, offset=LIFRU_OFFSET)[0]


def fmt_addr(addr: SockAddr) -> str:
    "Produce a printable form of a socket address: '<addr> port <port>'"
    (addrstr, portstr) = socket.getnameinfo(
        addr, socket.NI_NUMERICHOST | socket.NI_NUMERICSERV)
    return addrstr + ' port ' + portstr


class TestProto:
    """ Abstract(-ish) base class for test protocols.

    Subclasses set the class attributes `proto` and `type` and implement
    server_thread() and run_initiator(). """

    sockobj: socket.socket
    proto: int = -1    # IPPROTO_* value, set by subclasses
    type: int = -1     # SOCK_* value, set by subclasses
    thread: Thread
    ifindex: Optional[int]

    def __init__(self, name: str, family: int, addr: SockAddr) -> None:
        self.name = name
        self.family = family
        self.addr = addr
        self.ifindex = None

    def set_ifindex(self, ifindex: int) -> None:
        "Save an ifindex for later use by bind_ifindex()"
        self.ifindex = ifindex

    def bind_ifindex(self) -> None:
        "Apply saved ifindex (if any) to the socket"
        if self.ifindex is None:
            return
        print('bind to ifindex', self.ifindex)
        if self.family == socket.AF_INET6:
            self.sockobj.setsockopt(socket.IPPROTO_IPV6, IPV6_BOUND_IF,
                                    self.ifindex)
        else:
            self.sockobj.setsockopt(socket.IPPROTO_IP, IP_BOUND_IF,
                                    self.ifindex)

    def setup_listener(self) -> None:
        "Create and bind the responder's socket"
        self.sockobj = socket.socket(family=self.family, type=self.type,
                                     proto=self.proto)
        # SO_REUSEADDR lets several sockets bind the same port on
        # different addresses, which is what this test exercises.
        self.sockobj.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.bind_ifindex()
        self.sockobj.bind(self.addr)

    def start_responder(self) -> None:
        "Create socket and start a (daemon) server thread"
        self.setup_listener()
        self.thread = Thread(target=self.server_thread, name=self.name,
                             daemon=True)
        self.thread.start()

    def server_thread(self) -> None:
        "Placeholder server thread body; subclasses override"
        raise ValueError

    def run_initiator(self) -> None:
        "Placeholder test client; subclasses override"
        raise ValueError


class TestProtoTcp(TestProto):
    """ Simple echo test for TCP sockets """

    proto = socket.IPPROTO_TCP
    type = socket.SOCK_STREAM

    def setup_listener(self) -> None:
        "Create, bind, and start listening on the responder's socket"
        super().setup_listener()
        self.sockobj.listen(5)

    @staticmethod
    def _recv_exactly(sock: socket.socket, nbytes: int) -> bytes:
        """Read from a stream socket until nbytes have arrived or the peer
        closes the connection; return everything received (fewer than
        nbytes only on early EOF)."""
        chunks: List[bytes] = []
        remaining = nbytes
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def conn_thread(self, conn: socket.socket) -> None:
        "Secondary thread: echo everything received on conn until EOF"
        # 'with' closes conn even if recv/send raises (e.g. peer reset).
        with conn:
            while True:
                buf = conn.recv(2048)
                if not buf:
                    return
                # send() on a stream socket may write only part of buf.
                conn.sendall(buf)

    def server_thread(self) -> None:
        "Accept connections forever, echoing each in its own thread"
        while True:
            (conn, fromaddr) = self.sockobj.accept()
            print('accepted connection from', fmt_addr(fromaddr))

            t = Thread(target=self.conn_thread, name='connection',
                       daemon=True, args=[conn])
            t.start()

    def run_initiator(self) -> None:
        """Connect to self.addr, send a message, and check it is echoed.

        Raises ValueError on an echo mismatch; socket errors (including
        timeouts) propagate as OSError."""
        self.sockobj = socket.socket(family=self.family, type=self.type,
                                     proto=self.proto)
        with self.sockobj:
            self.bind_ifindex()
            self.sockobj.settimeout(1.0)
            self.sockobj.connect(self.addr)

            msg = b'hello, world\n'
            self.sockobj.sendall(msg)
            # TCP is a byte stream: the echo may arrive in several pieces.
            buf = self._recv_exactly(self.sockobj, len(msg))
        if msg == buf:
            print(self.name, 'passed')
        else:
            raise ValueError('message mismatch', msg, buf)


class TestProtoUdp(TestProto):
    """ Simple echo test for UDP sockets """

    proto = socket.IPPROTO_UDP
    type = socket.SOCK_DGRAM

    def server_thread(self) -> None:
        "Echo every received datagram back to its sender, forever"
        while True:
            (buf, fromaddr) = self.sockobj.recvfrom(2048)
            print('server received', len(buf), 'bytes from',
                  fmt_addr(fromaddr))
            self.sockobj.sendto(buf, fromaddr)

    def run_initiator(self) -> None:
        """Send a datagram to self.addr and check it is echoed back.

        Raises ValueError on an echo mismatch; socket errors (including
        timeouts) propagate as OSError."""
        self.sockobj = socket.socket(family=self.family, type=self.type,
                                     proto=self.proto)
        with self.sockobj:
            self.bind_ifindex()
            self.sockobj.settimeout(0.1)
            # connect() on UDP only fixes the default peer address.
            self.sockobj.connect(self.addr)

            msg = b'hello, world from %s\n' % bytes(self.name,
                                                    encoding='utf-8')
            self.sockobj.send(msg)
            (buf, fromaddr) = self.sockobj.recvfrom(2048)
            print('initiator received', len(buf), 'bytes from',
                  fmt_addr(fromaddr))
        if msg == buf:
            print(self.name, 'passed')
        else:
            raise ValueError('message mismatch', msg, buf)


test_map = {
    'udp': TestProtoUdp,
    'tcp': TestProtoTcp,
}

family_map = {
    '4': socket.AF_INET,
    '6': socket.AF_INET6,
}


def get_addr(addr: str, port: int, family: int,
             proto: Type[TestProto]) -> Tuple[SockAddr, Optional[int]]:
    """Pull sockaddr,ifindex pair out of a command line argument;
    accept either 'addr' or 'ifname,addr' syntax.

    The address and port must be numeric (no name resolution is done).
    Raises socket.gaierror (an OSError) for an unparseable address and
    ValueError for a bad interface name."""
    ifindex = None

    if ',' in addr:
        (ifname, addr) = addr.split(',', maxsplit=1)
        ifindex = get_ifindex(ifname)

    sa = socket.getaddrinfo(addr, port, family=family, proto=proto.proto,
                            flags=socket.AI_NUMERICHOST | socket.AI_NUMERICSERV)

    return (sa[0][4], ifindex)


def main(argv: List[str]) -> int:
    """Multi-socket test. Bind several sockets; connect to several
    specified addresses.

    Returns 0 if every test passed, 1 on the first failure."""

    parser = argparse.ArgumentParser(prog='dup-bind')

    parser.add_argument('--proto', choices=test_map.keys(), required=True)
    parser.add_argument('--family', choices=family_map.keys(), required=True)
    parser.add_argument('--port', type=int, required=True)
    parser.add_argument('--addr', action='append',
                        help="'addr' or 'ifname,addr' to listen on; repeatable")
    parser.add_argument('test', nargs='+')

    args = parser.parse_args(argv)

    endpoints = []

    family = family_map[args.family]
    test_proto = test_map[args.proto]

    try:
        # --addr is optional; argparse leaves it None when never given.
        for addrstr in args.addr or []:
            print('listen on', addrstr)
            (saddr, ifindex) = get_addr(addrstr, args.port, family,
                                        test_proto)

            test_addr = test_proto(name=addrstr, family=family, addr=saddr)
            if ifindex is not None:
                test_addr.set_ifindex(ifindex)
            test_addr.start_responder()
            endpoints.append(test_addr)

        for addr in args.test:
            print('test to', addr)
            (saddr, ifindex) = get_addr(addr, args.port, family, test_proto)

            test_addr = test_proto(name=addr, family=family, addr=saddr)
            if ifindex is not None:
                test_addr.set_ifindex(ifindex)
            test_addr.run_initiator()
    except (ValueError, OSError) as err:
        # socket.timeout and socket.gaierror are both OSError subclasses,
        # so this also covers connect/recv timeouts and bad addresses.
        print('FAIL:', str(err))
        return 1

    return 0


if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))