xref: /linux/tools/testing/selftests/net/mptcp/mptcp_inq.c (revision 24b10e5f8e0d2bee1a10fc67011ea5d936c1a389)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #define _GNU_SOURCE
4 
5 #include <assert.h>
6 #include <errno.h>
7 #include <fcntl.h>
8 #include <limits.h>
9 #include <string.h>
10 #include <stdarg.h>
11 #include <stdbool.h>
12 #include <stdint.h>
13 #include <inttypes.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <strings.h>
17 #include <unistd.h>
18 #include <time.h>
19 
20 #include <sys/ioctl.h>
21 #include <sys/random.h>
22 #include <sys/socket.h>
23 #include <sys/types.h>
24 #include <sys/wait.h>
25 
26 #include <netdb.h>
27 #include <netinet/in.h>
28 
29 #include <linux/tcp.h>
30 #include <linux/sockios.h>
31 
/* Fallback values for C library headers that do not know MPTCP yet. */
#if !defined(IPPROTO_MPTCP)
# define IPPROTO_MPTCP 262
#endif
#if !defined(SOL_MPTCP)
# define SOL_MPTCP 284
#endif

/* Defaults; parse_opts() overrides them via -6, -t and -r. */
static int pf = AF_INET;		/* address family of both endpoints */
static int proto_tx = IPPROTO_MPTCP;	/* protocol of the connecting side */
static int proto_rx = IPPROTO_MPTCP;	/* protocol of the listening side */
42 
43 static void die_perror(const char *msg)
44 {
45 	perror(msg);
46 	exit(1);
47 }
48 
49 static void die_usage(int r)
50 {
51 	fprintf(stderr, "Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n");
52 	exit(r);
53 }
54 
55 static void xerror(const char *fmt, ...)
56 {
57 	va_list ap;
58 
59 	va_start(ap, fmt);
60 	vfprintf(stderr, fmt, ap);
61 	va_end(ap);
62 	fputc('\n', stderr);
63 	exit(1);
64 }
65 
66 static const char *getxinfo_strerr(int err)
67 {
68 	if (err == EAI_SYSTEM)
69 		return strerror(errno);
70 
71 	return gai_strerror(err);
72 }
73 
74 static void xgetaddrinfo(const char *node, const char *service,
75 			 const struct addrinfo *hints,
76 			 struct addrinfo **res)
77 {
78 	int err = getaddrinfo(node, service, hints, res);
79 
80 	if (err) {
81 		const char *errstr = getxinfo_strerr(err);
82 
83 		fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
84 			node ? node : "", service ? service : "", errstr);
85 		exit(1);
86 	}
87 }
88 
89 static int sock_listen_mptcp(const char * const listenaddr,
90 			     const char * const port)
91 {
92 	int sock = -1;
93 	struct addrinfo hints = {
94 		.ai_protocol = IPPROTO_TCP,
95 		.ai_socktype = SOCK_STREAM,
96 		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
97 	};
98 
99 	hints.ai_family = pf;
100 
101 	struct addrinfo *a, *addr;
102 	int one = 1;
103 
104 	xgetaddrinfo(listenaddr, port, &hints, &addr);
105 	hints.ai_family = pf;
106 
107 	for (a = addr; a; a = a->ai_next) {
108 		sock = socket(a->ai_family, a->ai_socktype, proto_rx);
109 		if (sock < 0)
110 			continue;
111 
112 		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
113 				     sizeof(one)))
114 			perror("setsockopt");
115 
116 		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
117 			break; /* success */
118 
119 		perror("bind");
120 		close(sock);
121 		sock = -1;
122 	}
123 
124 	freeaddrinfo(addr);
125 
126 	if (sock < 0)
127 		xerror("could not create listen socket");
128 
129 	if (listen(sock, 20))
130 		die_perror("listen");
131 
132 	return sock;
133 }
134 
135 static int sock_connect_mptcp(const char * const remoteaddr,
136 			      const char * const port, int proto)
137 {
138 	struct addrinfo hints = {
139 		.ai_protocol = IPPROTO_TCP,
140 		.ai_socktype = SOCK_STREAM,
141 	};
142 	struct addrinfo *a, *addr;
143 	int sock = -1;
144 
145 	hints.ai_family = pf;
146 
147 	xgetaddrinfo(remoteaddr, port, &hints, &addr);
148 	for (a = addr; a; a = a->ai_next) {
149 		sock = socket(a->ai_family, a->ai_socktype, proto);
150 		if (sock < 0)
151 			continue;
152 
153 		if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
154 			break; /* success */
155 
156 		die_perror("connect");
157 	}
158 
159 	if (sock < 0)
160 		xerror("could not create connect socket");
161 
162 	freeaddrinfo(addr);
163 	return sock;
164 }
165 
166 static int protostr_to_num(const char *s)
167 {
168 	if (strcasecmp(s, "tcp") == 0)
169 		return IPPROTO_TCP;
170 	if (strcasecmp(s, "mptcp") == 0)
171 		return IPPROTO_MPTCP;
172 
173 	die_usage(1);
174 	return 0;
175 }
176 
177 static void parse_opts(int argc, char **argv)
178 {
179 	int c;
180 
181 	while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
182 		switch (c) {
183 		case 'h':
184 			die_usage(0);
185 			break;
186 		case '6':
187 			pf = AF_INET6;
188 			break;
189 		case 't':
190 			proto_tx = protostr_to_num(optarg);
191 			break;
192 		case 'r':
193 			proto_rx = protostr_to_num(optarg);
194 			break;
195 		default:
196 			die_usage(1);
197 			break;
198 		}
199 	}
200 }
201 
202 /* wait up to timeout milliseconds */
203 static void wait_for_ack(int fd, int timeout, size_t total)
204 {
205 	int i;
206 
207 	for (i = 0; i < timeout; i++) {
208 		int nsd, ret, queued = -1;
209 		struct timespec req;
210 
211 		ret = ioctl(fd, TIOCOUTQ, &queued);
212 		if (ret < 0)
213 			die_perror("TIOCOUTQ");
214 
215 		ret = ioctl(fd, SIOCOUTQNSD, &nsd);
216 		if (ret < 0)
217 			die_perror("SIOCOUTQNSD");
218 
219 		if ((size_t)queued > total)
220 			xerror("TIOCOUTQ %u, but only %zu expected\n", queued, total);
221 		assert(nsd <= queued);
222 
223 		if (queued == 0)
224 			return;
225 
226 		/* wait for peer to ack rx of all data */
227 		req.tv_sec = 0;
228 		req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
229 		nanosleep(&req, NULL);
230 	}
231 
232 	xerror("still tx data queued after %u ms\n", timeout);
233 }
234 
235 static void connect_one_server(int fd, int unixfd)
236 {
237 	size_t len, i, total, sent;
238 	char buf[4096], buf2[4096];
239 	ssize_t ret;
240 
241 	len = rand() % (sizeof(buf) - 1);
242 
243 	if (len < 128)
244 		len = 128;
245 
246 	for (i = 0; i < len ; i++) {
247 		buf[i] = rand() % 26;
248 		buf[i] += 'A';
249 	}
250 
251 	buf[i] = '\n';
252 
253 	/* un-block server */
254 	ret = read(unixfd, buf2, 4);
255 	assert(ret == 4);
256 
257 	assert(strncmp(buf2, "xmit", 4) == 0);
258 
259 	ret = write(unixfd, &len, sizeof(len));
260 	assert(ret == (ssize_t)sizeof(len));
261 
262 	ret = write(fd, buf, len);
263 	if (ret < 0)
264 		die_perror("write");
265 
266 	if (ret != (ssize_t)len)
267 		xerror("short write");
268 
269 	ret = read(unixfd, buf2, 4);
270 	assert(strncmp(buf2, "huge", 4) == 0);
271 
272 	total = rand() % (16 * 1024 * 1024);
273 	total += (1 * 1024 * 1024);
274 	sent = total;
275 
276 	ret = write(unixfd, &total, sizeof(total));
277 	assert(ret == (ssize_t)sizeof(total));
278 
279 	wait_for_ack(fd, 5000, len);
280 
281 	while (total > 0) {
282 		if (total > sizeof(buf))
283 			len = sizeof(buf);
284 		else
285 			len = total;
286 
287 		ret = write(fd, buf, len);
288 		if (ret < 0)
289 			die_perror("write");
290 		total -= ret;
291 
292 		/* we don't have to care about buf content, only
293 		 * number of total bytes sent
294 		 */
295 	}
296 
297 	ret = read(unixfd, buf2, 4);
298 	assert(ret == 4);
299 	assert(strncmp(buf2, "shut", 4) == 0);
300 
301 	wait_for_ack(fd, 5000, sent);
302 
303 	ret = write(fd, buf, 1);
304 	assert(ret == 1);
305 	close(fd);
306 	ret = write(unixfd, "closed", 6);
307 	assert(ret == 6);
308 
309 	close(unixfd);
310 }
311 
312 static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
313 {
314 	struct cmsghdr *cmsg;
315 
316 	for (cmsg = CMSG_FIRSTHDR(msgh); cmsg ; cmsg = CMSG_NXTHDR(msgh, cmsg)) {
317 		if (cmsg->cmsg_level == IPPROTO_TCP && cmsg->cmsg_type == TCP_CM_INQ) {
318 			memcpy(inqv, CMSG_DATA(cmsg), sizeof(*inqv));
319 			return;
320 		}
321 	}
322 
323 	xerror("could not find TCP_CM_INQ cmsg type");
324 }
325 
326 static void process_one_client(int fd, int unixfd)
327 {
328 	unsigned int tcp_inq;
329 	size_t expect_len;
330 	char msg_buf[4096];
331 	char buf[4096];
332 	char tmp[16];
333 	struct iovec iov = {
334 		.iov_base = buf,
335 		.iov_len = 1,
336 	};
337 	struct msghdr msg = {
338 		.msg_iov = &iov,
339 		.msg_iovlen = 1,
340 		.msg_control = msg_buf,
341 		.msg_controllen = sizeof(msg_buf),
342 	};
343 	ssize_t ret, tot;
344 
345 	ret = write(unixfd, "xmit", 4);
346 	assert(ret == 4);
347 
348 	ret = read(unixfd, &expect_len, sizeof(expect_len));
349 	assert(ret == (ssize_t)sizeof(expect_len));
350 
351 	if (expect_len > sizeof(buf))
352 		xerror("expect len %zu exceeds buffer size", expect_len);
353 
354 	for (;;) {
355 		struct timespec req;
356 		unsigned int queued;
357 
358 		ret = ioctl(fd, FIONREAD, &queued);
359 		if (ret < 0)
360 			die_perror("FIONREAD");
361 		if (queued > expect_len)
362 			xerror("FIONREAD returned %u, but only %zu expected\n",
363 			       queued, expect_len);
364 		if (queued == expect_len)
365 			break;
366 
367 		req.tv_sec = 0;
368 		req.tv_nsec = 1000 * 1000ul;
369 		nanosleep(&req, NULL);
370 	}
371 
372 	/* read one byte, expect cmsg to return expected - 1 */
373 	ret = recvmsg(fd, &msg, 0);
374 	if (ret < 0)
375 		die_perror("recvmsg");
376 
377 	if (msg.msg_controllen == 0)
378 		xerror("msg_controllen is 0");
379 
380 	get_tcp_inq(&msg, &tcp_inq);
381 
382 	assert((size_t)tcp_inq == (expect_len - 1));
383 
384 	iov.iov_len = sizeof(buf);
385 	ret = recvmsg(fd, &msg, 0);
386 	if (ret < 0)
387 		die_perror("recvmsg");
388 
389 	/* should have gotten exact remainder of all pending data */
390 	assert(ret == (ssize_t)tcp_inq);
391 
392 	/* should be 0, all drained */
393 	get_tcp_inq(&msg, &tcp_inq);
394 	assert(tcp_inq == 0);
395 
396 	/* request a large swath of data. */
397 	ret = write(unixfd, "huge", 4);
398 	assert(ret == 4);
399 
400 	ret = read(unixfd, &expect_len, sizeof(expect_len));
401 	assert(ret == (ssize_t)sizeof(expect_len));
402 
403 	/* peer should send us a few mb of data */
404 	if (expect_len <= sizeof(buf))
405 		xerror("expect len %zu too small\n", expect_len);
406 
407 	tot = 0;
408 	do {
409 		iov.iov_len = sizeof(buf);
410 		ret = recvmsg(fd, &msg, 0);
411 		if (ret < 0)
412 			die_perror("recvmsg");
413 
414 		tot += ret;
415 
416 		get_tcp_inq(&msg, &tcp_inq);
417 
418 		if (tcp_inq > expect_len - tot)
419 			xerror("inq %d, remaining %d total_len %d\n",
420 			       tcp_inq, expect_len - tot, (int)expect_len);
421 
422 		assert(tcp_inq <= expect_len - tot);
423 	} while ((size_t)tot < expect_len);
424 
425 	ret = write(unixfd, "shut", 4);
426 	assert(ret == 4);
427 
428 	/* wait for hangup. Should have received one more byte of data. */
429 	ret = read(unixfd, tmp, sizeof(tmp));
430 	assert(ret == 6);
431 	assert(strncmp(tmp, "closed", 6) == 0);
432 
433 	sleep(1);
434 
435 	iov.iov_len = 1;
436 	ret = recvmsg(fd, &msg, 0);
437 	if (ret < 0)
438 		die_perror("recvmsg");
439 	assert(ret == 1);
440 
441 	get_tcp_inq(&msg, &tcp_inq);
442 
443 	/* tcp_inq should be 1 due to received fin. */
444 	assert(tcp_inq == 1);
445 
446 	iov.iov_len = 1;
447 	ret = recvmsg(fd, &msg, 0);
448 	if (ret < 0)
449 		die_perror("recvmsg");
450 
451 	/* expect EOF */
452 	assert(ret == 0);
453 	get_tcp_inq(&msg, &tcp_inq);
454 	assert(tcp_inq == 1);
455 
456 	close(fd);
457 }
458 
459 static int xaccept(int s)
460 {
461 	int fd = accept(s, NULL, 0);
462 
463 	if (fd < 0)
464 		die_perror("accept");
465 
466 	return fd;
467 }
468 
469 static int server(int unixfd)
470 {
471 	int fd = -1, r, on = 1;
472 
473 	switch (pf) {
474 	case AF_INET:
475 		fd = sock_listen_mptcp("127.0.0.1", "15432");
476 		break;
477 	case AF_INET6:
478 		fd = sock_listen_mptcp("::1", "15432");
479 		break;
480 	default:
481 		xerror("Unknown pf %d\n", pf);
482 		break;
483 	}
484 
485 	r = write(unixfd, "conn", 4);
486 	assert(r == 4);
487 
488 	alarm(15);
489 	r = xaccept(fd);
490 
491 	if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
492 		die_perror("setsockopt");
493 
494 	process_one_client(r, unixfd);
495 
496 	return 0;
497 }
498 
499 static int client(int unixfd)
500 {
501 	int fd = -1;
502 
503 	alarm(15);
504 
505 	switch (pf) {
506 	case AF_INET:
507 		fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
508 		break;
509 	case AF_INET6:
510 		fd = sock_connect_mptcp("::1", "15432", proto_tx);
511 		break;
512 	default:
513 		xerror("Unknown pf %d\n", pf);
514 	}
515 
516 	connect_one_server(fd, unixfd);
517 
518 	return 0;
519 }
520 
521 static void init_rng(void)
522 {
523 	unsigned int foo;
524 
525 	if (getrandom(&foo, sizeof(foo), 0) == -1) {
526 		perror("getrandom");
527 		exit(1);
528 	}
529 
530 	srand(foo);
531 }
532 
533 static pid_t xfork(void)
534 {
535 	pid_t p = fork();
536 
537 	if (p < 0)
538 		die_perror("fork");
539 	else if (p == 0)
540 		init_rng();
541 
542 	return p;
543 }
544 
545 static int rcheck(int wstatus, const char *what)
546 {
547 	if (WIFEXITED(wstatus)) {
548 		if (WEXITSTATUS(wstatus) == 0)
549 			return 0;
550 		fprintf(stderr, "%s exited, status=%d\n", what, WEXITSTATUS(wstatus));
551 		return WEXITSTATUS(wstatus);
552 	} else if (WIFSIGNALED(wstatus)) {
553 		xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
554 	} else if (WIFSTOPPED(wstatus)) {
555 		xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));
556 	}
557 
558 	return 111;
559 }
560 
561 int main(int argc, char *argv[])
562 {
563 	int e1, e2, wstatus;
564 	pid_t s, c, ret;
565 	int unixfds[2];
566 
567 	parse_opts(argc, argv);
568 
569 	e1 = socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds);
570 	if (e1 < 0)
571 		die_perror("pipe");
572 
573 	s = xfork();
574 	if (s == 0)
575 		return server(unixfds[1]);
576 
577 	close(unixfds[1]);
578 
579 	/* wait until server bound a socket */
580 	e1 = read(unixfds[0], &e1, 4);
581 	assert(e1 == 4);
582 
583 	c = xfork();
584 	if (c == 0)
585 		return client(unixfds[0]);
586 
587 	close(unixfds[0]);
588 
589 	ret = waitpid(s, &wstatus, 0);
590 	if (ret == -1)
591 		die_perror("waitpid");
592 	e1 = rcheck(wstatus, "server");
593 	ret = waitpid(c, &wstatus, 0);
594 	if (ret == -1)
595 		die_perror("waitpid");
596 	e2 = rcheck(wstatus, "client");
597 
598 	return e1 ? e1 : e2;
599 }
600