xref: /illumos-gate/usr/src/test/net-tests/tests/ipv6/dup_bind.py (revision 64130b0be265e6f79e86a9f3c515fb40680f25b1)
1#!@PYTHON@
2
3#
4# This file and its contents are supplied under the terms of the
5# Common Development and Distribution License ("CDDL"), version 1.0.
6# You may only use this file in accordance with the terms of version
7# 1.0 of the CDDL.
8#
9# A full copy of the text of the CDDL should have accompanied this
10# source.  A copy of the CDDL is also available via the Internet at
11# http://www.illumos.org/license/CDDL.
12#
13
14#
15# Copyright 2024 Bill Sommerfeld <sommerfeld@hamachi.org>
16#
17
18""" Set up multiple bound sockets and then send/connect to some of them.
19Helps to test that link-local addresses are properly scoped """
20
21import argparse
22import fcntl
23import struct
24import socket
25import sys
26from threading import Thread
27from typing import List, Optional, Tuple, Type, Union
28
# A socket address as produced by getaddrinfo(): (host, port) for IPv4,
# (host, port, flowinfo, scope_id) for IPv6.
SockAddr = Union[Tuple[str, int], Tuple[str, int, int, int]]

# Socket options that pin a socket to one interface (by ifindex).  They are
# hard-coded here rather than taken from the socket module.
IP_BOUND_IF = 0x41
IPV6_BOUND_IF = 0x41

# SIOCGLIFINDEX and the pieces of struct lifreq that get_ifindex() touches.
# NOTE(review): these appear to match the 64-bit (LP64) struct lifreq
# layout -- confirm if this is ever run as a 32-bit process.
SIOCGLIFINDEX = 0xc0786985      # _IOWR('i', 133, struct lifreq)
LIFREQSIZE = 376                # sizeof (struct lifreq)
LIFNAMSIZ = 32                  # size of lifr_name, including the NUL
LIFRU_OFFSET = 40               # offset of the lifr_lifru union
38
def get_ifindex(arg: str) -> int:
    """Return the interface index of the interface named *arg*.

    A zeroed struct lifreq buffer gets the interface name copied into
    lifr_name and is handed to the SIOCGLIFINDEX ioctl on a temporary IPv6
    UDP socket; the index is then read back from the lifr_lifru union.

    Raises ValueError if *arg* is too long to fit in lifr_name with a
    terminating NUL, or (as UnicodeEncodeError) if it is not ASCII.
    OSError propagates if the ioctl fails.
    """
    lifreq = bytearray(LIFREQSIZE)

    name = bytes(arg, encoding='ascii')
    if not len(name) < LIFNAMSIZ:
        # at least one trailing zero byte must remain as the terminator
        raise ValueError('Interface name too long', arg)
    lifreq[:len(name)] = name

    with socket.socket(family=socket.AF_INET6, type=socket.SOCK_DGRAM,
                       proto=socket.IPPROTO_UDP) as sock:
        # fcntl.ioctl() writes the kernel's reply back into lifreq
        fcntl.ioctl(sock.fileno(), SIOCGLIFINDEX, lifreq)
        (index,) = struct.unpack_from('i', lifreq, LIFRU_OFFSET)
    return index
53
def fmt_addr(addr: SockAddr) -> str:
    """Render *addr* as '<numeric host> port <numeric port>' for log output.

    Both parts come from getnameinfo() with NI_NUMERICHOST|NI_NUMERICSERV,
    so no name lookups are performed.
    """
    host, port = socket.getnameinfo(
        addr, socket.NI_NUMERICHOST | socket.NI_NUMERICSERV)
    return f'{host} port {port}'
59
class TestProto:
    """Abstract(-ish) base class for test protocols.

    Subclasses set the class attributes ``proto`` and ``type`` (passed to
    socket.socket()) and override server_thread() and run_initiator().
    Each instance describes one endpoint: a responder bound to ``addr``,
    or an initiator that talks to ``addr``, optionally pinned to an
    interface by index.
    """

    sockobj: socket.socket      # the endpoint's socket, once created
    proto: int = -1             # protocol for socket.socket(); set by subclasses
    type: int = -1              # socket type for socket.socket(); set by subclasses
    thread: Thread              # responder thread, set by start_responder()
    ifindex: Optional[int]      # interface to bind to, or None for no binding

    def __init__(self, name: str, family: int, addr: SockAddr) -> None:
        """Record the endpoint's label, address family and socket address."""
        self.name = name
        self.family = family
        self.addr = addr
        self.ifindex = None

    def set_ifindex(self, ifindex: int) -> None:
        "Save an ifindex for later"
        self.ifindex = ifindex

    def bind_ifindex(self) -> None:
        """Apply the saved ifindex (if any) to self.sockobj.

        Uses IPV6_BOUND_IF for AF_INET6 endpoints and IP_BOUND_IF for any
        other family.  Does nothing when no ifindex was saved.
        """
        if self.ifindex is None:
            return
        print('bind to ifindex', self.ifindex)
        if self.family == socket.AF_INET6:
            self.sockobj.setsockopt(socket.IPPROTO_IPV6, IPV6_BOUND_IF,
                                    self.ifindex)
        else:
            self.sockobj.setsockopt(socket.IPPROTO_IP, IP_BOUND_IF,
                                    self.ifindex)

    def setup_listener(self) -> None:
        """Create the responder's socket and bind it to self.addr.

        SO_REUSEADDR and the saved ifindex (if any) are applied before
        bind().  If any step fails the new socket is closed before the
        exception propagates, so a failed bind does not leak a descriptor.
        """
        self.sockobj = socket.socket(family=self.family, type=self.type,
                                     proto=self.proto)
        try:
            self.sockobj.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.bind_ifindex()
            self.sockobj.bind(self.addr)
        except BaseException:
            self.sockobj.close()
            raise

    def start_responder(self) -> None:
        """Set up the listening socket and serve it from a daemon thread
        named after this endpoint."""
        self.setup_listener()
        self.thread = Thread(target=self.server_thread, name=self.name,
                             daemon=True)
        self.thread.start()

    def server_thread(self) -> None:
        """Placeholder server thread body; subclasses must override.

        Raises ValueError (the type main() reports as a failure) with a
        message naming the offending class.
        """
        raise ValueError('%s does not implement server_thread'
                         % self.__class__.__name__)

    def run_initiator(self) -> None:
        """Placeholder test client; subclasses must override.

        Raises ValueError (the type main() reports as a failure) with a
        message naming the offending class.
        """
        raise ValueError('%s does not implement run_initiator'
                         % self.__class__.__name__)
109
class TestProtoTcp(TestProto):
    """Simple echo test for TCP sockets."""

    proto = socket.IPPROTO_TCP
    type = socket.SOCK_STREAM

    def setup_listener(self) -> None:
        """Create and bind the responder socket, then start listening."""
        super().setup_listener()
        self.sockobj.listen(5)

    def conn_thread(self, conn: socket.socket) -> None:
        """Echo everything received on *conn* back until the peer closes it.

        The connection is closed on EOF, and also if recv()/sendall()
        raises.
        """
        with conn:
            while True:
                buf = conn.recv(2048)
                if not buf:
                    return
                # send() may accept only part of buf; sendall() retries
                conn.sendall(buf)

    def server_thread(self) -> None:
        """Accept connections forever, echoing each on its own daemon thread."""
        while True:
            (conn, fromaddr) = self.sockobj.accept()
            print('accepted connection from', fmt_addr(fromaddr))

            t = Thread(target=self.conn_thread, name='connection',
                       daemon=True, args=[conn])
            t.start()

    def run_initiator(self) -> None:
        """Connect to self.addr, send a message and check it is echoed back.

        Each blocking socket call times out after 1 second.  The socket is
        closed before returning.

        Raises ValueError if the echo differs from what was sent; OSError
        (including socket.timeout) on network failure.
        """
        self.sockobj = socket.socket(family=self.family, type=self.type,
                                     proto=self.proto)
        try:
            self.bind_ifindex()
            self.sockobj.settimeout(1.0)
            self.sockobj.connect(self.addr)

            msg = b'hello, world\n'
            self.sockobj.sendall(msg)
            buf = self._recv_upto(len(msg))
        finally:
            self.sockobj.close()

        if msg == buf:
            print(self.name, 'passed')
        else:
            raise ValueError('message mismatch', msg, buf)

    def _recv_upto(self, count: int) -> bytes:
        """Read from self.sockobj until *count* bytes arrive or the peer closes.

        TCP is a byte stream, so a single recv() may return only part of
        the echoed message.
        """
        received = bytearray()
        while len(received) < count:
            chunk = self.sockobj.recv(2048)
            if not chunk:
                break
            received += chunk
        return bytes(received)
152
class TestProtoUdp(TestProto):
    """Simple echo test for UDP sockets."""

    proto = socket.IPPROTO_UDP
    type = socket.SOCK_DGRAM

    def server_thread(self) -> None:
        """Echo every received datagram back to the address it came from."""
        while True:
            (buf, fromaddr) = self.sockobj.recvfrom(2048)
            print('server received', len(buf), 'bytes from', fmt_addr(fromaddr))
            self.sockobj.sendto(buf, fromaddr)

    def run_initiator(self) -> None:
        """Send one datagram to self.addr and check it is echoed back.

        The socket is connect()ed to self.addr, each blocking call times
        out after 0.1 seconds, and the socket is closed before returning.

        Raises ValueError if the echo differs from what was sent; OSError
        (including socket.timeout) on network failure.
        """
        self.sockobj = socket.socket(family=self.family, type=self.type,
                                     proto=self.proto)
        try:
            self.bind_ifindex()
            self.sockobj.settimeout(0.1)
            self.sockobj.connect(self.addr)

            # Tag the payload with the endpoint name so replies are distinct
            msg = b'hello, world from %s\n' % bytes(self.name, encoding='utf-8')
            self.sockobj.send(msg)
            (buf, fromaddr) = self.sockobj.recvfrom(2048)
        finally:
            self.sockobj.close()

        print('initiator received', len(buf), 'bytes from', fmt_addr(fromaddr))
        if msg == buf:
            print(self.name, 'passed')
        else:
            raise ValueError('message mismatch', msg, buf)
180
# --proto choices: protocol name -> TestProto subclass that implements it.
test_map = dict(
    udp=TestProtoUdp,
    tcp=TestProtoTcp,
)

# --family choices: IP version -> socket address family.
family_map = {
    '4': socket.AF_INET,
    '6': socket.AF_INET6,
}
190
def get_addr(addr: str, port: int, family: int,
             proto: Type[TestProto]) -> Tuple[SockAddr, Optional[int]]:
    """Turn a command-line address into a (sockaddr, ifindex) pair.

    *addr* is either a numeric address or 'ifname,address'.  In the second
    form the interface name (the text before the first comma) is resolved
    with get_ifindex(); otherwise the returned ifindex is None.  The
    address and *port* are resolved numerically (AI_NUMERICHOST |
    AI_NUMERICSERV, so no DNS) for *family* and ``proto.proto``, and the
    sockaddr of the first getaddrinfo() result is returned.
    """
    ifindex: Optional[int] = None
    host = addr
    if ',' in addr:
        ifname, _, host = addr.partition(',')
        ifindex = get_ifindex(ifname)

    candidates = socket.getaddrinfo(
        host, port, family=family, proto=proto.proto,
        flags=socket.AI_NUMERICHOST | socket.AI_NUMERICSERV)
    return candidates[0][4], ifindex
205
def _make_test(addrstr: str, port: int, family: int,
               test_proto: Type[TestProto]) -> TestProto:
    """Build a *test_proto* endpoint from an 'addr' or 'ifname,addr' string."""
    (saddr, ifindex) = get_addr(addrstr, port, family, test_proto)
    endpoint = test_proto(name=addrstr, family=family, addr=saddr)
    if ifindex is not None:
        endpoint.set_ifindex(ifindex)
    return endpoint


def main(argv: List[str]) -> int:
    """Multi-socket test.  Bind several sockets; connect to several specified addresses.

    Starts a responder for every --addr (zero or more), then runs an
    initiator against every positional test address.  Returns 0 if all
    tests passed, or 1 after printing the first failure.
    """
    parser = argparse.ArgumentParser(prog='dup-bind')

    parser.add_argument('--proto', choices=test_map.keys(), required=True)
    parser.add_argument('--family', choices=family_map.keys(), required=True)
    parser.add_argument('--port', type=int, required=True)
    parser.add_argument('--addr', action='append')
    parser.add_argument('test', nargs='+')

    args = parser.parse_args(argv)

    family = family_map[args.family]
    test_proto = test_map[args.proto]

    endpoints = []

    try:
        # argparse leaves an 'append' option as None when it never appears
        for addrstr in args.addr or []:
            print('listen on', addrstr)
            endpoint = _make_test(addrstr, args.port, family, test_proto)
            endpoint.start_responder()
            endpoints.append(endpoint)

        for addrstr in args.test:
            print('test to', addrstr)
            _make_test(addrstr, args.port, family, test_proto).run_initiator()
    except (ValueError, OSError) as err:
        # socket.timeout and socket.gaierror are OSError subclasses, so
        # they are reported here too.
        print('FAIL:', str(err))
        return 1

    return 0
254
if __name__ == '__main__':
    # The process exit status is whatever main() returns.
    raise SystemExit(main(sys.argv[1:]))
257