xref: /linux/tools/testing/selftests/net/tcp_fastopen_backup_key.c (revision 4201c9260a8d3c4ef238e51692a7e9b4e1e29efe)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 /*
4  * Test key rotation for TFO.
5  * New keys are 'rotated' in two steps:
6  * 1) Add new key as the 'backup' key 'behind' the primary key
7  * 2) Make new key the primary by swapping the backup and primary keys
8  *
 * The rotation is done in stages using multiple sockets bound
 * to the same port via SO_REUSEPORT. This simulates key rotation
 * behind, say, a load balancer. Across the rotation no TFO cookie
 * should ever be rejected, i.e. TcpExtTCPFastOpenPassiveFail should
 * remain 0. This program only generates the traffic; that counter has
 * to be checked by whoever runs it.
14  */
15 #define _GNU_SOURCE
#include <arpa/inet.h>
#include <errno.h>
#include <error.h>
#include <fcntl.h>
#include <netinet/tcp.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <time.h>
#include <unistd.h>
28 
/* Fallback for system headers that do not define the option yet. */
#ifndef TCP_FASTOPEN_KEY
#define TCP_FASTOPEN_KEY 33
#endif

#ifndef ARRAY_SIZE
#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))
#endif

/* sysctl holding the TFO key(s) used when -s is not given */
#define PROC_FASTOPEN_KEY "/proc/sys/net/ipv4/tcp_fastopen_key"

enum {
	N_LISTEN = 10,		/* number of SO_REUSEPORT receive sockets */
	KEY_LENGTH = 16,	/* bytes in a single TFO key */
};

/* Command-line options, filled in by parse_opts(). */
static bool do_ipv6;
static bool do_sockopt;
static bool do_rotate;
/* Length handed to setsockopt(TCP_FASTOPEN_KEY): one key, or two with -r. */
static int key_len = KEY_LENGTH;

/* Runtime state: the receive sockets and the open sysctl file. */
static int rcv_fds[N_LISTEN];
static int proc_fd;

/* Client destination addresses and the shared listening port. */
static const char *IP4_ADDR = "127.0.0.1";
static const char *IP6_ADDR = "::1";
static const int PORT = 8891;
50 
51 static void get_keys(int fd, uint32_t *keys)
52 {
53 	char buf[128];
54 	int len = KEY_LENGTH * 2;
55 
56 	if (do_sockopt) {
57 		if (getsockopt(fd, SOL_TCP, TCP_FASTOPEN_KEY, keys, &len))
58 			error(1, errno, "Unable to get key");
59 		return;
60 	}
61 	lseek(proc_fd, 0, SEEK_SET);
62 	if (read(proc_fd, buf, sizeof(buf)) <= 0)
63 		error(1, errno, "Unable to read %s", PROC_FASTOPEN_KEY);
64 	if (sscanf(buf, "%x-%x-%x-%x,%x-%x-%x-%x", keys, keys + 1, keys + 2,
65 	    keys + 3, keys + 4, keys + 5, keys + 6, keys + 7) != 8)
66 		error(1, 0, "Unable to parse %s", PROC_FASTOPEN_KEY);
67 }
68 
69 static void set_keys(int fd, uint32_t *keys)
70 {
71 	char buf[128];
72 
73 	if (do_sockopt) {
74 		if (setsockopt(fd, SOL_TCP, TCP_FASTOPEN_KEY, keys,
75 		    key_len))
76 			error(1, errno, "Unable to set key");
77 		return;
78 	}
79 	if (do_rotate)
80 		snprintf(buf, 128, "%08x-%08x-%08x-%08x,%08x-%08x-%08x-%08x",
81 			 keys[0], keys[1], keys[2], keys[3], keys[4], keys[5],
82 			 keys[6], keys[7]);
83 	else
84 		snprintf(buf, 128, "%08x-%08x-%08x-%08x",
85 			 keys[0], keys[1], keys[2], keys[3]);
86 	lseek(proc_fd, 0, SEEK_SET);
87 	if (write(proc_fd, buf, sizeof(buf)) <= 0)
88 		error(1, errno, "Unable to write %s", PROC_FASTOPEN_KEY);
89 }
90 
91 static void build_rcv_fd(int family, int proto, int *rcv_fds)
92 {
93 	struct sockaddr_in  addr4 = {0};
94 	struct sockaddr_in6 addr6 = {0};
95 	struct sockaddr *addr;
96 	int opt = 1, i, sz;
97 	int qlen = 100;
98 	uint32_t keys[8];
99 
100 	switch (family) {
101 	case AF_INET:
102 		addr4.sin_family = family;
103 		addr4.sin_addr.s_addr = htonl(INADDR_ANY);
104 		addr4.sin_port = htons(PORT);
105 		sz = sizeof(addr4);
106 		addr = (struct sockaddr *)&addr4;
107 		break;
108 	case AF_INET6:
109 		addr6.sin6_family = AF_INET6;
110 		addr6.sin6_addr = in6addr_any;
111 		addr6.sin6_port = htons(PORT);
112 		sz = sizeof(addr6);
113 		addr = (struct sockaddr *)&addr6;
114 		break;
115 	default:
116 		error(1, 0, "Unsupported family %d", family);
117 		/* clang does not recognize error() above as terminating
118 		 * the program, so it complains that saddr, sz are
119 		 * not initialized when this code path is taken. Silence it.
120 		 */
121 		return;
122 	}
123 	for (i = 0; i < ARRAY_SIZE(keys); i++)
124 		keys[i] = rand();
125 	for (i = 0; i < N_LISTEN; i++) {
126 		rcv_fds[i] = socket(family, proto, 0);
127 		if (rcv_fds[i] < 0)
128 			error(1, errno, "failed to create receive socket");
129 		if (setsockopt(rcv_fds[i], SOL_SOCKET, SO_REUSEPORT, &opt,
130 			       sizeof(opt)))
131 			error(1, errno, "failed to set SO_REUSEPORT");
132 		if (bind(rcv_fds[i], addr, sz))
133 			error(1, errno, "failed to bind receive socket");
134 		if (setsockopt(rcv_fds[i], SOL_TCP, TCP_FASTOPEN, &qlen,
135 			       sizeof(qlen)))
136 			error(1, errno, "failed to set TCP_FASTOPEN");
137 		set_keys(rcv_fds[i], keys);
138 		if (proto == SOCK_STREAM && listen(rcv_fds[i], 10))
139 			error(1, errno, "failed to listen on receive port");
140 	}
141 }
142 
143 static int connect_and_send(int family, int proto)
144 {
145 	struct sockaddr_in  saddr4 = {0};
146 	struct sockaddr_in  daddr4 = {0};
147 	struct sockaddr_in6 saddr6 = {0};
148 	struct sockaddr_in6 daddr6 = {0};
149 	struct sockaddr *saddr, *daddr;
150 	int fd, sz, ret;
151 	char data[1];
152 
153 	switch (family) {
154 	case AF_INET:
155 		saddr4.sin_family = AF_INET;
156 		saddr4.sin_addr.s_addr = htonl(INADDR_ANY);
157 		saddr4.sin_port = 0;
158 
159 		daddr4.sin_family = AF_INET;
160 		if (!inet_pton(family, IP4_ADDR, &daddr4.sin_addr.s_addr))
161 			error(1, errno, "inet_pton failed: %s", IP4_ADDR);
162 		daddr4.sin_port = htons(PORT);
163 
164 		sz = sizeof(saddr4);
165 		saddr = (struct sockaddr *)&saddr4;
166 		daddr = (struct sockaddr *)&daddr4;
167 		break;
168 	case AF_INET6:
169 		saddr6.sin6_family = AF_INET6;
170 		saddr6.sin6_addr = in6addr_any;
171 
172 		daddr6.sin6_family = AF_INET6;
173 		if (!inet_pton(family, IP6_ADDR, &daddr6.sin6_addr))
174 			error(1, errno, "inet_pton failed: %s", IP6_ADDR);
175 		daddr6.sin6_port = htons(PORT);
176 
177 		sz = sizeof(saddr6);
178 		saddr = (struct sockaddr *)&saddr6;
179 		daddr = (struct sockaddr *)&daddr6;
180 		break;
181 	default:
182 		error(1, 0, "Unsupported family %d", family);
183 		/* clang does not recognize error() above as terminating
184 		 * the program, so it complains that saddr, daddr, sz are
185 		 * not initialized when this code path is taken. Silence it.
186 		 */
187 		return -1;
188 	}
189 	fd = socket(family, proto, 0);
190 	if (fd < 0)
191 		error(1, errno, "failed to create send socket");
192 	if (bind(fd, saddr, sz))
193 		error(1, errno, "failed to bind send socket");
194 	data[0] = 'a';
195 	ret = sendto(fd, data, 1, MSG_FASTOPEN, daddr, sz);
196 	if (ret != 1)
197 		error(1, errno, "failed to sendto");
198 
199 	return fd;
200 }
201 
202 static bool is_listen_fd(int fd)
203 {
204 	int i;
205 
206 	for (i = 0; i < N_LISTEN; i++) {
207 		if (rcv_fds[i] == fd)
208 			return true;
209 	}
210 	return false;
211 }
212 
213 static int rotate_key(int fd)
214 {
215 	static int iter;
216 	static uint32_t new_key[4];
217 	uint32_t keys[8];
218 	uint32_t tmp_key[4];
219 	int i;
220 	int len = KEY_LENGTH * 2;
221 
222 	if (iter < N_LISTEN) {
223 		/* first set new key as backups */
224 		if (iter == 0) {
225 			for (i = 0; i < ARRAY_SIZE(new_key); i++)
226 				new_key[i] = rand();
227 		}
228 		get_keys(fd, keys);
229 		memcpy(keys + 4, new_key, KEY_LENGTH);
230 		set_keys(fd, keys);
231 	} else {
232 		/* swap the keys */
233 		get_keys(fd, keys);
234 		memcpy(tmp_key, keys + 4, KEY_LENGTH);
235 		memcpy(keys + 4, keys, KEY_LENGTH);
236 		memcpy(keys, tmp_key, KEY_LENGTH);
237 		set_keys(fd, keys);
238 	}
239 	if (++iter >= (N_LISTEN * 2))
240 		iter = 0;
241 }
242 
243 static void run_one_test(int family)
244 {
245 	struct epoll_event ev;
246 	int i, send_fd;
247 	int n_loops = 10000;
248 	int rotate_key_fd = 0;
249 	int key_rotate_interval = 50;
250 	int fd, epfd;
251 	char buf[1];
252 
253 	build_rcv_fd(family, SOCK_STREAM, rcv_fds);
254 	epfd = epoll_create(1);
255 	if (epfd < 0)
256 		error(1, errno, "failed to create epoll");
257 	ev.events = EPOLLIN;
258 	for (i = 0; i < N_LISTEN; i++) {
259 		ev.data.fd = rcv_fds[i];
260 		if (epoll_ctl(epfd, EPOLL_CTL_ADD, rcv_fds[i], &ev))
261 			error(1, errno, "failed to register sock epoll");
262 	}
263 	while (n_loops--) {
264 		send_fd = connect_and_send(family, SOCK_STREAM);
265 		if (do_rotate && ((n_loops % key_rotate_interval) == 0)) {
266 			rotate_key(rcv_fds[rotate_key_fd]);
267 			if (++rotate_key_fd >= N_LISTEN)
268 				rotate_key_fd = 0;
269 		}
270 		while (1) {
271 			i = epoll_wait(epfd, &ev, 1, -1);
272 			if (i < 0)
273 				error(1, errno, "epoll_wait failed");
274 			if (is_listen_fd(ev.data.fd)) {
275 				fd = accept(ev.data.fd, NULL, NULL);
276 				if (fd < 0)
277 					error(1, errno, "failed to accept");
278 				ev.data.fd = fd;
279 				if (epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &ev))
280 					error(1, errno, "failed epoll add");
281 				continue;
282 			}
283 			i = recv(ev.data.fd, buf, sizeof(buf), 0);
284 			if (i != 1)
285 				error(1, errno, "failed recv data");
286 			if (epoll_ctl(epfd, EPOLL_CTL_DEL, ev.data.fd, NULL))
287 				error(1, errno, "failed epoll del");
288 			close(ev.data.fd);
289 			break;
290 		}
291 		close(send_fd);
292 	}
293 	for (i = 0; i < N_LISTEN; i++)
294 		close(rcv_fds[i]);
295 }
296 
297 static void parse_opts(int argc, char **argv)
298 {
299 	int c;
300 
301 	while ((c = getopt(argc, argv, "46sr")) != -1) {
302 		switch (c) {
303 		case '4':
304 			do_ipv6 = false;
305 			break;
306 		case '6':
307 			do_ipv6 = true;
308 			break;
309 		case 's':
310 			do_sockopt = true;
311 			break;
312 		case 'r':
313 			do_rotate = true;
314 			key_len = KEY_LENGTH * 2;
315 			break;
316 		default:
317 			error(1, 0, "%s: parse error", argv[0]);
318 		}
319 	}
320 }
321 
322 int main(int argc, char **argv)
323 {
324 	parse_opts(argc, argv);
325 	proc_fd = open(PROC_FASTOPEN_KEY, O_RDWR);
326 	if (proc_fd < 0)
327 		error(1, errno, "Unable to open %s", PROC_FASTOPEN_KEY);
328 	srand(time(NULL));
329 	if (do_ipv6)
330 		run_one_test(AF_INET6);
331 	else
332 		run_one_test(AF_INET);
333 	close(proc_fd);
334 	fprintf(stderr, "PASS\n");
335 	return 0;
336 }
337