#!@PYTHON@

#
# This file and its contents are supplied under the terms of the
# Common Development and Distribution License ("CDDL"), version 1.0.
# You may only use this file in accordance with the terms of version
# 1.0 of the CDDL.
#
# A full copy of the text of the CDDL should have accompanied this
# source. A copy of the CDDL is also available via the Internet at
# http://www.illumos.org/license/CDDL.
#

#
# Copyright 2024 Bill Sommerfeld <sommerfeld@hamachi.org>
#

""" Set up multiple bound sockets and then send/connect to some of them.
Helps to test that link-local addresses are properly scoped """

import argparse
import fcntl
import struct
import socket
import sys
from threading import Thread
from typing import List, Optional, Tuple, Type, Union

# Socket address as accepted by bind()/connect(): (host, port) for IPv4,
# (host, port, flowinfo, scope_id) for IPv6.
SockAddr = Union[Tuple[str, int], Tuple[str, int, int, int]]

# Socket option numbers used to bind a socket to an interface index.
# NOTE(review): presumably these mirror the illumos <netinet/in.h>
# definitions -- confirm if porting to another platform.
IP_BOUND_IF=0x41
IPV6_BOUND_IF=0x41

# ioctl that maps an interface name to its index, plus the pieces of the
# struct lifreq layout needed to build and parse its argument buffer.
# NOTE(review): these look like the illumos <net/if.h> values for the
# 64-bit struct lifreq layout -- confirm against the target ABI.
SIOCGLIFINDEX=0xc0786985
LIFREQSIZE=376
LIFNAMSIZ=32
LIFRU_OFFSET=40

def get_ifindex(arg: str) -> int:
    """Look up the ifindex corresponding to a named interface.

    Raises ValueError if the name is not ASCII or is too long to fit
    (with its NUL terminator) in the lifreq name field, and OSError if
    the SIOCGLIFINDEX ioctl fails.
    """
    buf = bytearray(LIFREQSIZE)

    ifname = bytes(arg, encoding='ascii')
    # '>=' rather than '>' leaves room for the terminating NUL byte.
    if len(ifname) >= LIFNAMSIZ:
        raise ValueError('Interface name too long', arg)
    buf[0:len(ifname)] = ifname

    with socket.socket(family=socket.AF_INET6,
                       type=socket.SOCK_DGRAM,
                       proto=socket.IPPROTO_UDP) as s:
        # buf is mutable, so the kernel's answer is written back into it.
        fcntl.ioctl(s.fileno(), SIOCGLIFINDEX, buf)
    return struct.unpack_from('i', buffer=buf, offset=LIFRU_OFFSET)[0]

def fmt_addr(addr: SockAddr) -> str:
    "Produce a printable form of a socket address ('<host> port <port>')"
    (addrstr, portstr) = socket.getnameinfo(
        addr, socket.NI_NUMERICHOST|socket.NI_NUMERICSERV)
    return addrstr + ' port ' + portstr

class TestProto:
    """ Abstract(-ish) base class for test protocols.

    Subclasses set the class attributes `proto` and `type` (passed to
    socket.socket()) and implement server_thread() for the responder side
    and run_initiator() for the client side.
    """

    sockobj: socket.socket
    proto: int = -1     # IPPROTO_* value; overridden by subclasses
    type: int = -1      # SOCK_* value; overridden by subclasses
    thread: Thread
    ifindex: Optional[int]

    def __init__(self, name: str, family: int, addr: SockAddr) -> None:
        self.name = name
        self.family = family
        self.addr = addr
        self.ifindex = None

    def set_ifindex(self, ifindex: int) -> None:
        "Save an ifindex for later"
        self.ifindex = ifindex

    def bind_ifindex(self) -> None:
        "Apply saved ifindex (if any) to the socket"

        if self.ifindex is not None:
            print('bind to ifindex', self.ifindex)
            if self.family==socket.AF_INET6:
                self.sockobj.setsockopt(socket.IPPROTO_IPV6, IPV6_BOUND_IF, self.ifindex)
            else:
                self.sockobj.setsockopt(socket.IPPROTO_IP, IP_BOUND_IF, self.ifindex)

    def setup_listener(self) -> None:
        "Create a listening socket for the responder"
        self.sockobj = socket.socket(family=self.family, type=self.type, proto=self.proto)
        self.sockobj.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # The interface binding must be in place before bind().
        self.bind_ifindex()
        self.sockobj.bind(self.addr)

    def start_responder(self) -> None:
        "Create socket and start a server thread"
        self.setup_listener()
        self.thread = Thread(target=self.server_thread, name=self.name, daemon=True)
        self.thread.start()

    def server_thread(self) -> None:
        "Placeholder server thread body; subclasses must override"
        raise ValueError

    def run_initiator(self) -> None:
        "Placeholder test client; subclasses must override"
        raise ValueError

class TestProtoTcp(TestProto):
    """ Simple test for TCP sockets """

    proto=socket.IPPROTO_TCP
    type=socket.SOCK_STREAM

    def setup_listener(self) -> None:
        super().setup_listener()
        self.sockobj.listen(5)

    def conn_thread(self, conn: socket.socket) -> None:
        "Secondary thread to echo everything received on an accepted connection"
        # 'with' closes the connection on EOF and on any error alike.
        with conn:
            while True:
                buf = conn.recv(2048)
                if len(buf) == 0:
                    return
                # send() may write only part of buf; sendall() loops until done.
                conn.sendall(buf)

    def server_thread(self) -> None:
        "Accept connections forever, handing each to its own echo thread"
        while True:
            (conn, fromaddr) = self.sockobj.accept()
            print('accepted connection from', fmt_addr(fromaddr))

            t = Thread(target=self.conn_thread, name='connection', daemon=True, args=[conn])
            t.start()

    def run_initiator(self) -> None:
        """Connect to the responder, send a message and check the echo.

        Raises ValueError on a mismatched echo and OSError (including
        socket.timeout) on network failure.
        """
        with socket.socket(family=self.family, type=self.type, proto=self.proto) as s:
            self.sockobj = s
            self.bind_ifindex()
            s.settimeout(1.0)
            s.connect(self.addr)

            msg=b'hello, world\n'
            s.sendall(msg)
            # TCP is a byte stream: the echo may arrive in several pieces.
            buf = b''
            while len(buf) < len(msg):
                chunk = s.recv(2048)
                if not chunk:
                    break
                buf += chunk
        if msg == buf:
            print(self.name, 'passed')
        else:
            raise ValueError('message mismatch', msg, buf)

class TestProtoUdp(TestProto):
    """ Simple test for UDP sockets """

    proto=socket.IPPROTO_UDP
    type=socket.SOCK_DGRAM

    def server_thread(self) -> None:
        "Echo every datagram back to its sender, forever"
        while True:
            (buf, fromaddr) = self.sockobj.recvfrom(2048)
            print('server received', len(buf), 'bytes from',fmt_addr(fromaddr))
            self.sockobj.sendto(buf, fromaddr)

    def run_initiator(self) -> None:
        """Send one datagram to the responder and check that it is echoed.

        Raises ValueError on a mismatched echo and OSError (including
        socket.timeout) on network failure.
        """
        with socket.socket(family=self.family, type=self.type, proto=self.proto) as s:
            self.sockobj = s
            self.bind_ifindex()
            s.settimeout(0.1)
            s.connect(self.addr)

            msg=b'hello, world from %s\n' % bytes(self.name, encoding='utf-8')
            s.send(msg)
            (buf, fromaddr) = s.recvfrom(2048)
        print('initiator received', len(buf), 'bytes from', fmt_addr(fromaddr))
        if msg == buf:
            print(self.name, 'passed')
        else:
            raise ValueError('message mismatch', msg, buf)

test_map = {
    'udp': TestProtoUdp,
    'tcp': TestProtoTcp,
}

family_map = {
    '4': socket.AF_INET,
    '6': socket.AF_INET6,
}

def get_addr(addr: str, port: int, family: int,
             proto: Type[TestProto]) -> Tuple[SockAddr, Optional[int]]:
    """Pull sockaddr,ifindex pair out of a command line argument;
    accept either 'addr' or 'ifname,addr' syntax.

    The address must be numeric (no DNS lookups). ifindex is None when no
    interface name was given. Raises OSError (socket.gaierror) for an
    unparsable address.
    """
    ifindex = None

    if ',' in addr:
        (ifname, addr) = addr.split(',', maxsplit=1)
        ifindex = get_ifindex(ifname)

    sa = socket.getaddrinfo(addr, port, family=family, proto=proto.proto,
                            flags=socket.AI_NUMERICHOST|socket.AI_NUMERICSERV)

    return (sa[0][4], ifindex)

def _make_endpoint(addrstr: str, port: int, family: int,
                   test_proto: Type[TestProto]) -> TestProto:
    "Build a test endpoint from an 'addr' or 'ifname,addr' argument"
    (saddr, ifindex) = get_addr(addrstr, port, family, test_proto)
    endpoint = test_proto(name=addrstr, family=family, addr=saddr)
    if ifindex is not None:
        endpoint.set_ifindex(ifindex)
    return endpoint

def main(argv: List[str]) -> int:
    """Multi-socket test. Bind several sockets; connect to several specified addresses.

    Returns 0 if every test passed, 1 on the first failure.
    """

    parser = argparse.ArgumentParser(prog='dup-bind')

    parser.add_argument('--proto', choices=test_map.keys(), required=True)
    parser.add_argument('--family', choices=family_map.keys(), required=True)
    parser.add_argument('--port', type=int, required=True)
    parser.add_argument('--addr', action='append')
    parser.add_argument('test', nargs='+')

    args = parser.parse_args(argv)

    # Holds the responders so they stay referenced for the whole run.
    endpoints = []

    family=family_map[args.family]
    test_proto=test_map[args.proto]

    try:
        # An 'append' option that is never given is None, not [].
        for addrstr in args.addr or []:
            print('listen on', addrstr)
            responder = _make_endpoint(addrstr, args.port, family, test_proto)
            responder.start_responder()
            endpoints.append(responder)

        for addr in args.test:
            print('test to', addr)
            initiator = _make_endpoint(addr, args.port, family, test_proto)
            initiator.run_initiator()
    except (ValueError, OSError) as err:
        # socket.timeout is a subclass of OSError, so timeouts land here too.
        print('FAIL:', str(err))
        return 1

    return 0

if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))