xref: /linux/tools/testing/selftests/net/mptcp/mptcp_inq.c (revision 1b98f357dadd6ea613a435fbaef1a5dd7b35fd21)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #define _GNU_SOURCE
4 
5 #include <assert.h>
6 #include <errno.h>
7 #include <fcntl.h>
8 #include <limits.h>
9 #include <string.h>
10 #include <stdarg.h>
11 #include <stdbool.h>
12 #include <stdint.h>
13 #include <inttypes.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <strings.h>
17 #include <unistd.h>
18 #include <time.h>
19 
20 #include <sys/ioctl.h>
21 #include <sys/random.h>
22 #include <sys/socket.h>
23 #include <sys/types.h>
24 #include <sys/wait.h>
25 
26 #include <netdb.h>
27 #include <netinet/in.h>
28 
29 #include <linux/tcp.h>
30 #include <linux/sockios.h>
31 
/* Fallback values for C library headers that do not know about MPTCP yet. */
#if !defined(IPPROTO_MPTCP)
#define IPPROTO_MPTCP 262
#endif
#if !defined(SOL_MPTCP)
#define SOL_MPTCP 284
#endif

/*
 * Test configuration, overridden from the command line by parse_opts():
 *   pf        address family used by both endpoints (-6 selects AF_INET6)
 *   proto_tx  protocol of the connecting (client) socket (-t tcp|mptcp)
 *   proto_rx  protocol of the listening (server) socket (-r tcp|mptcp)
 */
static int pf = AF_INET;
static int proto_tx = IPPROTO_MPTCP;
static int proto_rx = IPPROTO_MPTCP;
42 
/* Report @msg plus the description of the current errno, then exit(1). */
static void die_perror(const char *msg)
{
	const int failure = 1;

	perror(msg);
	exit(failure);
}
48 
/* Print the command-line synopsis on stderr and exit with status @r. */
static void die_usage(int r)
{
	static const char usage[] =
		"Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n";

	fputs(usage, stderr);
	exit(r);
}
54 
/*
 * Print a printf-style formatted message followed by a newline on stderr,
 * then exit with status 1.
 *
 * format(printf, 1, 2) lets -Wformat check every caller's arguments against
 * its format string (mismatches such as size_t passed for %d are undefined
 * behaviour); noreturn tells the compiler that control never comes back.
 */
static void __attribute__((format(printf, 1, 2), noreturn))
xerror(const char *fmt, ...)
{
	va_list ap;

	va_start(ap, fmt);
	vfprintf(stderr, fmt, ap);
	va_end(ap);
	fputc('\n', stderr);
	exit(1);
}
65 
/*
 * Describe a getaddrinfo() error code.  EAI_SYSTEM means the actual cause
 * is stored in errno, so the errno description is returned in that case.
 */
static const char *getxinfo_strerr(int err)
{
	return err == EAI_SYSTEM ? strerror(errno) : gai_strerror(err);
}
73 
/*
 * getaddrinfo() wrapper that exits the process on failure.
 *
 * If the lookup fails with EAI_SOCKTYPE (presumably because the C library
 * does not know IPPROTO_MPTCP), it is retried once with
 * hints->ai_protocol switched to IPPROTO_TCP; @hints is modified in place
 * in that case.  The callers in this file pass their own protocol to
 * socket(), so the ai_protocol of the returned entries is not relied upon.
 */
static void xgetaddrinfo(const char *node, const char *service,
			 struct addrinfo *hints,
			 struct addrinfo **res)
{
	int err = getaddrinfo(node, service, hints, res);

	/*
	 * Fall back only when TCP was not already requested, so a lookup
	 * that rejects TCP as well is reported instead of retried forever.
	 */
	if (err == EAI_SOCKTYPE && hints->ai_protocol != IPPROTO_TCP) {
		hints->ai_protocol = IPPROTO_TCP;
		err = getaddrinfo(node, service, hints, res);
	}

	if (err) {
		const char *errstr = getxinfo_strerr(err);

		fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
			node ? node : "", service ? service : "", errstr);
		exit(1);
	}
}
96 
/*
 * Create a stream socket of protocol proto_rx bound to
 * @listenaddr:@port and put it into listening state.  Each address returned
 * by getaddrinfo() is tried in turn until one binds.  Exits on failure;
 * returns the listening descriptor.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_MPTCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
	};
	struct addrinfo *a, *addr;
	int sock = -1;
	int one = 1;

	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto_rx);
		if (sock < 0)
			continue;

		/* best effort: a failure here is reported but not fatal */
		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
				     sizeof(one)))
			perror("setsockopt");

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create listen socket");

	if (listen(sock, 20))
		die_perror("listen");

	return sock;
}
142 
/*
 * Create a stream socket of protocol @proto and connect it to
 * @remoteaddr:@port.  Addresses for which socket() fails are skipped; a
 * connect() failure is fatal.  Exits if no socket could be created;
 * returns the connected descriptor.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_MPTCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *a, *addr;
	int sock = -1;

	xgetaddrinfo(remoteaddr, port, &hints, &addr);
	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto);
		if (sock < 0)
			continue;

		if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		die_perror("connect");
	}

	/* release the address list on every path, including the error one */
	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create connect socket");

	return sock;
}
173 
/*
 * Map a protocol name from the command line ("tcp" or "mptcp", compared
 * case-insensitively) to its IPPROTO_* number.  Any other name prints the
 * usage text and exits with status 1.
 */
static int protostr_to_num(const char *s)
{
	static const struct {
		const char *name;
		int proto;
	} protos[] = {
		{ "tcp",   IPPROTO_TCP },
		{ "mptcp", IPPROTO_MPTCP },
	};
	size_t idx;

	for (idx = 0; idx < sizeof(protos) / sizeof(protos[0]); idx++) {
		if (!strcasecmp(s, protos[idx].name))
			return protos[idx].proto;
	}

	die_usage(1);
	return 0;
}
184 
185 static void parse_opts(int argc, char **argv)
186 {
187 	int c;
188 
189 	while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
190 		switch (c) {
191 		case 'h':
192 			die_usage(0);
193 			break;
194 		case '6':
195 			pf = AF_INET6;
196 			break;
197 		case 't':
198 			proto_tx = protostr_to_num(optarg);
199 			break;
200 		case 'r':
201 			proto_rx = protostr_to_num(optarg);
202 			break;
203 		default:
204 			die_usage(1);
205 			break;
206 		}
207 	}
208 }
209 
/*
 * Wait until TIOCOUTQ reports an empty send queue on @fd, i.e. until the
 * peer has acknowledged everything written so far.  Polls once per
 * millisecond for at most @timeout iterations (roughly @timeout ms) and
 * exits on timeout.
 *
 * @total is the number of bytes written so far: a queue larger than that is
 * an error, and SIOCOUTQNSD (queued but not yet sent) may never exceed
 * TIOCOUTQ.
 */
static void wait_for_ack(int fd, int timeout, size_t total)
{
	int i;

	for (i = 0; i < timeout; i++) {
		int nsd, ret, queued = -1;
		struct timespec req;

		ret = ioctl(fd, TIOCOUTQ, &queued);
		if (ret < 0)
			die_perror("TIOCOUTQ");

		ret = ioctl(fd, SIOCOUTQNSD, &nsd);
		if (ret < 0)
			die_perror("SIOCOUTQNSD");

		/* queued and timeout are plain ints: print them with %d */
		if ((size_t)queued > total)
			xerror("TIOCOUTQ %d, but only %zu expected\n", queued, total);
		assert(nsd <= queued);

		if (queued == 0)
			return;

		/* wait for peer to ack rx of all data */
		req.tv_sec = 0;
		req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
		nanosleep(&req, NULL);
	}

	xerror("still tx data queued after %d ms\n", timeout);
}
242 
/*
 * Sending side of the test, run by the client on the connected socket @fd
 * and driven by the server through the @unixfd control channel:
 *  1. on "xmit": announce a random length (128..4094 bytes) and send that
 *     many random uppercase letters on @fd;
 *  2. on "huge": announce a random total in [1 MiB, 17 MiB), wait for the
 *     first message to be acked, then stream the total on @fd;
 *  3. on "shut": wait for everything to be acked, send one final byte,
 *     close @fd and report "closed".
 */
static void connect_one_server(int fd, int unixfd)
{
	size_t len, i, total, sent;
	char buf[4096], buf2[4096];
	ssize_t ret;

	len = rand() % (sizeof(buf) - 1);

	if (len < 128)
		len = 128;

	for (i = 0; i < len; i++) {
		buf[i] = rand() % 26;
		buf[i] += 'A';
	}

	buf[i] = '\n';

	/* un-block server */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);

	assert(strncmp(buf2, "xmit", 4) == 0);

	ret = write(unixfd, &len, sizeof(len));
	assert(ret == (ssize_t)sizeof(len));

	ret = write(fd, buf, len);
	if (ret < 0)
		die_perror("write");

	if (ret != (ssize_t)len)
		xerror("short write");

	/* check the length like every other control-channel read does */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "huge", 4) == 0);

	total = rand() % (16 * 1024 * 1024);
	total += (1 * 1024 * 1024);
	sent = total;

	ret = write(unixfd, &total, sizeof(total));
	assert(ret == (ssize_t)sizeof(total));

	wait_for_ack(fd, 5000, len);

	while (total > 0) {
		if (total > sizeof(buf))
			len = sizeof(buf);
		else
			len = total;

		ret = write(fd, buf, len);
		if (ret < 0)
			die_perror("write");
		total -= ret;

		/* we don't have to care about buf content, only
		 * number of total bytes sent
		 */
	}

	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "shut", 4) == 0);

	wait_for_ack(fd, 5000, sent);

	ret = write(fd, buf, 1);
	assert(ret == 1);
	close(fd);
	ret = write(unixfd, "closed", 6);
	assert(ret == 6);

	close(unixfd);
}
319 
/*
 * Copy the TCP_CM_INQ value (bytes still queued for reading after the
 * recvmsg() that filled @msgh) from the control messages into @inqv.
 * Exits if no such control message is present.
 */
static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
{
	struct cmsghdr *cm = CMSG_FIRSTHDR(msgh);

	while (cm) {
		if (cm->cmsg_level == IPPROTO_TCP &&
		    cm->cmsg_type == TCP_CM_INQ) {
			memcpy(inqv, CMSG_DATA(cm), sizeof(*inqv));
			return;
		}
		cm = CMSG_NXTHDR(msgh, cm);
	}

	xerror("could not find TCP_CM_INQ cmsg type");
}
333 
/*
 * Receiving side of the test, run by the server on the accepted connection
 * @fd (server() enables TCP_INQ on it first) and synchronised with the
 * client through the @unixfd control channel:
 *  1. "xmit": the client announces a short message length; once FIONREAD
 *     shows it fully queued, TCP_CM_INQ must be len - 1 after reading one
 *     byte and 0 after draining the rest;
 *  2. "huge": the client announces a multi-megabyte stream; TCP_CM_INQ may
 *     never exceed the number of bytes still outstanding;
 *  3. "shut": once the client reports "closed", its final byte must be
 *     read with TCP_CM_INQ == 1 (the received FIN), followed by EOF.
 *
 * recvmsg() overwrites msg_controllen with the length of the control data
 * it returned, so it is restored to the full buffer size before every call
 * rather than relying on whatever the previous call left behind.
 */
static void process_one_client(int fd, int unixfd)
{
	unsigned int tcp_inq;
	size_t expect_len;
	char msg_buf[4096];
	char buf[4096];
	char tmp[16];
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = 1,
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = msg_buf,
		.msg_controllen = sizeof(msg_buf),
	};
	ssize_t ret, tot;

	ret = write(unixfd, "xmit", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	if (expect_len > sizeof(buf))
		xerror("expect len %zu exceeds buffer size", expect_len);

	/* poll in 1ms steps until the whole short message is queued */
	for (;;) {
		struct timespec req;
		unsigned int queued;

		ret = ioctl(fd, FIONREAD, &queued);
		if (ret < 0)
			die_perror("FIONREAD");
		if (queued > expect_len)
			xerror("FIONREAD returned %u, but only %zu expected\n",
			       queued, expect_len);
		if (queued == expect_len)
			break;

		req.tv_sec = 0;
		req.tv_nsec = 1000 * 1000ul;
		nanosleep(&req, NULL);
	}

	/* read one byte, expect cmsg to return expected - 1 */
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	if (msg.msg_controllen == 0)
		xerror("msg_controllen is 0");

	get_tcp_inq(&msg, &tcp_inq);

	assert((size_t)tcp_inq == (expect_len - 1));

	iov.iov_len = sizeof(buf);
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* should have gotten exact remainder of all pending data */
	assert(ret == (ssize_t)tcp_inq);

	/* should be 0, all drained */
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 0);

	/* request a large swath of data. */
	ret = write(unixfd, "huge", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	/* peer should send us a few mb of data */
	if (expect_len <= sizeof(buf))
		xerror("expect len %zu too small\n", expect_len);

	tot = 0;
	do {
		iov.iov_len = sizeof(buf);
		msg.msg_controllen = sizeof(msg_buf);
		ret = recvmsg(fd, &msg, 0);
		if (ret < 0)
			die_perror("recvmsg");
		/* the peer only closes after "shut"; EOF here would otherwise
		 * loop until the alarm fires
		 */
		if (ret == 0)
			xerror("unexpected EOF after %zd of %zu bytes",
			       tot, expect_len);

		tot += ret;
		/* keeps expect_len - tot below from wrapping around */
		if ((size_t)tot > expect_len)
			xerror("received %zd bytes, but only %zu expected",
			       tot, expect_len);

		get_tcp_inq(&msg, &tcp_inq);

		if (tcp_inq > expect_len - (size_t)tot)
			xerror("inq %u, remaining %zu total_len %zu\n",
			       tcp_inq, expect_len - (size_t)tot, expect_len);

		assert(tcp_inq <= expect_len - (size_t)tot);
	} while ((size_t)tot < expect_len);

	ret = write(unixfd, "shut", 4);
	assert(ret == 4);

	/* wait for hangup. Should have received one more byte of data. */
	ret = read(unixfd, tmp, sizeof(tmp));
	assert(ret == 6);
	assert(strncmp(tmp, "closed", 6) == 0);

	sleep(1);

	iov.iov_len = 1;
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");
	assert(ret == 1);

	get_tcp_inq(&msg, &tcp_inq);

	/* tcp_inq should be 1 due to received fin. */
	assert(tcp_inq == 1);

	iov.iov_len = 1;
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* expect EOF */
	assert(ret == 0);
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 1);

	close(fd);
}
466 
/* accept() one connection on listening socket @s; exits on failure. */
static int xaccept(int s)
{
	int conn;

	conn = accept(s, NULL, NULL);
	if (conn >= 0)
		return conn;

	die_perror("accept");
	return conn;
}
476 
477 static int server(int unixfd)
478 {
479 	int fd = -1, r, on = 1;
480 
481 	switch (pf) {
482 	case AF_INET:
483 		fd = sock_listen_mptcp("127.0.0.1", "15432");
484 		break;
485 	case AF_INET6:
486 		fd = sock_listen_mptcp("::1", "15432");
487 		break;
488 	default:
489 		xerror("Unknown pf %d\n", pf);
490 		break;
491 	}
492 
493 	r = write(unixfd, "conn", 4);
494 	assert(r == 4);
495 
496 	alarm(15);
497 	r = xaccept(fd);
498 
499 	if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
500 		die_perror("setsockopt");
501 
502 	process_one_client(r, unixfd);
503 
504 	return 0;
505 }
506 
507 static int client(int unixfd)
508 {
509 	int fd = -1;
510 
511 	alarm(15);
512 
513 	switch (pf) {
514 	case AF_INET:
515 		fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
516 		break;
517 	case AF_INET6:
518 		fd = sock_connect_mptcp("::1", "15432", proto_tx);
519 		break;
520 	default:
521 		xerror("Unknown pf %d\n", pf);
522 	}
523 
524 	connect_one_server(fd, unixfd);
525 
526 	return 0;
527 }
528 
/* Seed rand() with a value from getrandom(); exits if that call fails. */
static void init_rng(void)
{
	unsigned int seed;
	ssize_t got = getrandom(&seed, sizeof(seed), 0);

	if (got != -1) {
		srand(seed);
		return;
	}

	perror("getrandom");
	exit(1);
}
540 
/*
 * fork() wrapper: exits on failure, and in the child reseeds rand() so that
 * each forked process gets its own seed.  Returns fork()'s result.
 */
static pid_t xfork(void)
{
	pid_t pid = fork();

	if (pid > 0)
		return pid;		/* parent */

	if (pid == 0) {
		init_rng();		/* child */
		return pid;
	}

	die_perror("fork");
	return pid;
}
552 
/*
 * Convert the waitpid() status @wstatus of the process named @what into an
 * exit code: 0 for a clean exit, otherwise the child's exit status (which is
 * also logged).  A child killed or stopped by a signal is fatal; any other
 * status yields 111.
 */
static int rcheck(int wstatus, const char *what)
{
	int code;

	if (!WIFEXITED(wstatus)) {
		if (WIFSIGNALED(wstatus))
			xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
		else if (WIFSTOPPED(wstatus))
			xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));
		return 111;
	}

	code = WEXITSTATUS(wstatus);
	if (code != 0)
		fprintf(stderr, "%s exited, status=%d\n", what, code);

	return code;
}
568 
/*
 * Fork a server and a client process that coordinate over a datagram
 * socketpair, wait for both, and return the server's non-zero exit status
 * if any, else the client's.
 */
int main(int argc, char *argv[])
{
	int e1, e2, wstatus;
	pid_t s, c, ret;
	int unixfds[2];
	char sync_msg[4];
	ssize_t len;

	parse_opts(argc, argv);

	e1 = socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds);
	if (e1 < 0)
		die_perror("socketpair");

	s = xfork();
	if (s == 0)
		return server(unixfds[1]);

	close(unixfds[1]);

	/* wait until the server has bound its socket and sent "conn" */
	len = read(unixfds[0], sync_msg, sizeof(sync_msg));
	assert(len == 4);
	assert(memcmp(sync_msg, "conn", 4) == 0);

	/* the client inherits the parent's end of the control channel */
	c = xfork();
	if (c == 0)
		return client(unixfds[0]);

	close(unixfds[0]);

	ret = waitpid(s, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e1 = rcheck(wstatus, "server");
	ret = waitpid(c, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e2 = rcheck(wstatus, "client");

	return e1 ? e1 : e2;
}
608