xref: /linux/tools/testing/selftests/net/mptcp/mptcp_inq.c (revision f2a3b12b305c7bb72467b2a56d19a4587b6007f9)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #define _GNU_SOURCE
4 
5 #include <assert.h>
6 #include <errno.h>
7 #include <fcntl.h>
8 #include <limits.h>
9 #include <string.h>
10 #include <stdarg.h>
11 #include <stdbool.h>
12 #include <stdint.h>
13 #include <inttypes.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <strings.h>
17 #include <unistd.h>
18 #include <time.h>
19 
20 #include <sys/ioctl.h>
21 #include <sys/random.h>
22 #include <sys/socket.h>
23 #include <sys/types.h>
24 #include <sys/wait.h>
25 
26 #include <netdb.h>
27 #include <netinet/in.h>
28 
29 #include <linux/tcp.h>
30 #include <linux/sockios.h>
31 #include <linux/compiler.h>
32 
/* Fallback values for when the system headers do not provide them. */
#if !defined(IPPROTO_MPTCP)
#define IPPROTO_MPTCP 262
#endif
#if !defined(SOL_MPTCP)
#define SOL_MPTCP 284
#endif

/* Run-time configuration; overridden from the command line by parse_opts(). */
static int pf = AF_INET;		/* address family, AF_INET6 with -6 */
static int proto_tx = IPPROTO_MPTCP;	/* client socket protocol, -t */
static int proto_rx = IPPROTO_MPTCP;	/* server socket protocol, -r */
43 
die_perror(const char * msg)44 static void __noreturn die_perror(const char *msg)
45 {
46 	perror(msg);
47 	exit(1);
48 }
49 
/* Print the command-line synopsis on stderr and exit with status @r. */
static void die_usage(int r)
{
	fputs("Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n", stderr);
	exit(r);
}
55 
xerror(const char * fmt,...)56 static void __noreturn xerror(const char *fmt, ...)
57 {
58 	va_list ap;
59 
60 	va_start(ap, fmt);
61 	vfprintf(stderr, fmt, ap);
62 	va_end(ap);
63 	fputc('\n', stderr);
64 	exit(1);
65 }
66 
/*
 * Turn a getaddrinfo() error code into a readable string.  EAI_SYSTEM means
 * the actual cause is in errno, so report strerror(errno) in that case.
 */
static const char *getxinfo_strerr(int err)
{
	return err == EAI_SYSTEM ? strerror(errno) : gai_strerror(err);
}
74 
/*
 * getaddrinfo() wrapper that terminates the process on failure.
 *
 * If the lookup fails with EAI_SOCKTYPE, @hints->ai_protocol is switched to
 * IPPROTO_TCP (in place) and the lookup is retried.  Presumably this covers
 * C libraries that do not recognise IPPROTO_MPTCP, the callers' default.
 * The retry is attempted only while the protocol is not already IPPROTO_TCP,
 * so a persistent EAI_SOCKTYPE is reported instead of looping forever.
 *
 * On success *@res holds the resolver result; the caller releases it with
 * freeaddrinfo().
 */
static void xgetaddrinfo(const char *node, const char *service,
			 struct addrinfo *hints,
			 struct addrinfo **res)
{
	int err;

	for (;;) {
		err = getaddrinfo(node, service, hints, res);
		if (!err)
			return;

		/* only fall back once: a TCP retry that still fails is fatal */
		if (err != EAI_SOCKTYPE || hints->ai_protocol == IPPROTO_TCP)
			break;

		hints->ai_protocol = IPPROTO_TCP;
	}

	fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
		node ? node : "", service ? service : "",
		getxinfo_strerr(err));
	exit(1);
}
98 
/*
 * Create a socket of protocol proto_rx bound to @listenaddr:@port (numeric
 * host, address family pf) and put it into the listening state.
 *
 * Every address returned by the resolver is tried in turn until one can be
 * bound.  The process terminates if none can be bound or if listen() fails.
 *
 * Returns the listening socket descriptor.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_MPTCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
	};
	struct addrinfo *a, *addr;
	int sock = -1;
	int one = 1;

	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto_rx);
		if (sock < 0)
			continue;

		/* failure to set SO_REUSEADDR is reported but not fatal */
		if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
			       sizeof(one)) == -1)
			perror("setsockopt");

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create listen socket");

	if (listen(sock, 20))
		die_perror("listen");

	return sock;
}
144 
/*
 * Create a socket of protocol @proto and connect it to @remoteaddr:@port,
 * resolved with address family pf.
 *
 * Each resolved address is tried in turn, as sock_listen_mptcp() does for
 * bind().  A failed connect() closes that socket and moves on to the next
 * address.  The process terminates only if no address can be connected to.
 *
 * Returns the connected socket descriptor.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_MPTCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *a, *addr;
	int sock = -1;

	xgetaddrinfo(remoteaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto);
		if (sock < 0)
			continue;

		if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		/* don't leak the descriptor; try the next address instead */
		perror("connect");
		close(sock);
		sock = -1;
	}

	/* release the resolver result on the failure path as well */
	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create connect socket");

	return sock;
}
175 
/*
 * Translate "tcp" or "mptcp" (case-insensitive) into the matching IPPROTO_*
 * value.  Any other string prints the usage text and exits with status 1.
 */
static int protostr_to_num(const char *s)
{
	if (!strcasecmp(s, "mptcp"))
		return IPPROTO_MPTCP;
	if (!strcasecmp(s, "tcp"))
		return IPPROTO_TCP;

	die_usage(1);
	return 0; /* not reached: die_usage() exits */
}
186 
parse_opts(int argc,char ** argv)187 static void parse_opts(int argc, char **argv)
188 {
189 	int c;
190 
191 	while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
192 		switch (c) {
193 		case 'h':
194 			die_usage(0);
195 			break;
196 		case '6':
197 			pf = AF_INET6;
198 			break;
199 		case 't':
200 			proto_tx = protostr_to_num(optarg);
201 			break;
202 		case 'r':
203 			proto_rx = protostr_to_num(optarg);
204 			break;
205 		default:
206 			die_usage(1);
207 			break;
208 		}
209 	}
210 }
211 
212 /* wait up to timeout milliseconds */
wait_for_ack(int fd,int timeout,size_t total)213 static void wait_for_ack(int fd, int timeout, size_t total)
214 {
215 	int i;
216 
217 	for (i = 0; i < timeout; i++) {
218 		int nsd, ret, queued = -1;
219 		struct timespec req;
220 
221 		ret = ioctl(fd, TIOCOUTQ, &queued);
222 		if (ret < 0)
223 			die_perror("TIOCOUTQ");
224 
225 		ret = ioctl(fd, SIOCOUTQNSD, &nsd);
226 		if (ret < 0)
227 			die_perror("SIOCOUTQNSD");
228 
229 		if ((size_t)queued > total)
230 			xerror("TIOCOUTQ %u, but only %zu expected\n", queued, total);
231 		assert(nsd <= queued);
232 
233 		if (queued == 0)
234 			return;
235 
236 		/* wait for peer to ack rx of all data */
237 		req.tv_sec = 0;
238 		req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
239 		nanosleep(&req, NULL);
240 	}
241 
242 	xerror("still tx data queued after %u ms\n", timeout);
243 }
244 
/*
 * Client side of the test, driven by the server over the control socket
 * @unixfd.
 *
 *  1. On "xmit": pick a random length of 128..4094 bytes and report it on
 *     @unixfd.  Then send that many random uppercase letters on @fd.
 *  2. On "huge": pick a random total of at least 1 MiB and below 17 MiB and
 *     report it on @unixfd.  Wait until the first message has been fully
 *     acked, then stream that many bytes on @fd.
 *  3. On "shut": wait until the bulk data has been acked and send one more
 *     byte on @fd.  Close @fd, report "closed" on @unixfd, then close
 *     @unixfd.
 *
 * Every control-channel read and write is checked for its expected size.
 */
static void connect_one_server(int fd, int unixfd)
{
	size_t len, i, total, sent;
	char buf[4096], buf2[4096];
	ssize_t ret;

	len = rand() % (sizeof(buf) - 1);

	if (len < 128)
		len = 128;

	for (i = 0; i < len; i++) {
		buf[i] = rand() % 26;
		buf[i] += 'A';
	}

	buf[i] = '\n';

	/* un-block server */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);

	assert(strncmp(buf2, "xmit", 4) == 0);

	ret = write(unixfd, &len, sizeof(len));
	assert(ret == (ssize_t)sizeof(len));

	ret = write(fd, buf, len);
	if (ret < 0)
		die_perror("write");

	if (ret != (ssize_t)len)
		xerror("short write");

	/* check the sync read like every other one, not just its contents */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "huge", 4) == 0);

	total = rand() % (16 * 1024 * 1024);
	total += (1 * 1024 * 1024);
	sent = total;

	ret = write(unixfd, &total, sizeof(total));
	assert(ret == (ssize_t)sizeof(total));

	wait_for_ack(fd, 5000, len);

	while (total > 0) {
		if (total > sizeof(buf))
			len = sizeof(buf);
		else
			len = total;

		ret = write(fd, buf, len);
		if (ret < 0)
			die_perror("write");
		total -= ret;

		/* we don't have to care about buf content, only
		 * number of total bytes sent
		 */
	}

	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "shut", 4) == 0);

	wait_for_ack(fd, 5000, sent);

	ret = write(fd, buf, 1);
	assert(ret == 1);
	close(fd);
	ret = write(unixfd, "closed", 6);
	assert(ret == 6);

	close(unixfd);
}
321 
/*
 * Copy the value of the TCP_CM_INQ control message found in @msgh into
 * *@inqv.
 *
 * The value is the receive-queue byte count the kernel attaches when TCP_INQ
 * is enabled on the socket.  The process terminates if no such cmsg is
 * present, or if its payload is too short to hold an unsigned int (so that
 * memcpy never reads past the cmsg).
 */
static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
{
	struct cmsghdr *cmsg;

	for (cmsg = CMSG_FIRSTHDR(msgh); cmsg; cmsg = CMSG_NXTHDR(msgh, cmsg)) {
		if (cmsg->cmsg_level != IPPROTO_TCP || cmsg->cmsg_type != TCP_CM_INQ)
			continue;

		if (cmsg->cmsg_len < CMSG_LEN(sizeof(*inqv)))
			xerror("TCP_CM_INQ cmsg too short: %zu bytes",
			       (size_t)cmsg->cmsg_len);

		memcpy(inqv, CMSG_DATA(cmsg), sizeof(*inqv));
		return;
	}

	xerror("could not find TCP_CM_INQ cmsg type");
}
335 
/*
 * recvmsg() wrapper for process_one_client().
 *
 * Re-arms the in/out fields of @msg before every call: the single iovec's
 * length becomes @len and msg_controllen becomes @controllen.  This is needed
 * because recvmsg() overwrites msg_controllen with the amount of ancillary
 * data actually received.  Terminates the process on error.
 *
 * Returns the number of bytes read (0 on EOF).
 */
static ssize_t xrecvmsg_inq(int fd, struct msghdr *msg, size_t len,
			    size_t controllen)
{
	ssize_t ret;

	msg->msg_iov->iov_len = len;
	msg->msg_controllen = controllen;

	ret = recvmsg(fd, msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	return ret;
}

/*
 * Server side of the test, run on the accepted connection @fd (TCP_INQ
 * enabled).  @unixfd is the control channel to the client process.
 *
 *  1. "xmit": learn the message length and wait until FIONREAD reports
 *     exactly that many bytes.  Read one byte (TCP_CM_INQ must be len - 1),
 *     then read the rest (TCP_CM_INQ must be 0).
 *  2. "huge": learn the bulk size and drain it.  TCP_CM_INQ must never
 *     exceed the number of bytes still expected.
 *  3. "shut": wait for "closed", then read the final byte, where TCP_CM_INQ
 *     must be 1.  The following read must hit EOF with TCP_CM_INQ still 1.
 *
 * @fd is closed on return.
 */
static void process_one_client(int fd, int unixfd)
{
	unsigned int tcp_inq;
	size_t expect_len;
	char msg_buf[4096];
	char buf[4096];
	char tmp[16];
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = 1,
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = msg_buf,
		.msg_controllen = sizeof(msg_buf),
	};
	ssize_t ret, tot;

	ret = write(unixfd, "xmit", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	/* 0 would make the 1-byte read below block and underflow len - 1 */
	if (expect_len == 0)
		xerror("expect len is 0");
	if (expect_len > sizeof(buf))
		xerror("expect len %zu exceeds buffer size", expect_len);

	/* wait until the whole first message is queued for reading */
	for (;;) {
		const struct timespec one_ms = { .tv_sec = 0, .tv_nsec = 1000 * 1000L };
		unsigned int queued;

		if (ioctl(fd, FIONREAD, &queued) < 0)
			die_perror("FIONREAD");
		if (queued > expect_len)
			xerror("FIONREAD returned %u, but only %zu expected",
			       queued, expect_len);
		if (queued == expect_len)
			break;

		nanosleep(&one_ms, NULL);
	}

	/* read one byte, expect cmsg to return expected - 1 */
	ret = xrecvmsg_inq(fd, &msg, 1, sizeof(msg_buf));
	assert(ret == 1);

	if (msg.msg_controllen == 0)
		xerror("msg_controllen is 0");

	get_tcp_inq(&msg, &tcp_inq);

	assert((size_t)tcp_inq == (expect_len - 1));

	ret = xrecvmsg_inq(fd, &msg, sizeof(buf), sizeof(msg_buf));

	/* should have gotten exact remainder of all pending data */
	assert(ret == (ssize_t)tcp_inq);

	/* should be 0, all drained */
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 0);

	/* request a large swath of data. */
	ret = write(unixfd, "huge", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	/* peer should send us a few mb of data */
	if (expect_len <= sizeof(buf))
		xerror("expect len %zu too small", expect_len);

	tot = 0;
	do {
		size_t remaining;

		ret = xrecvmsg_inq(fd, &msg, sizeof(buf), sizeof(msg_buf));

		/* without this an early EOF would spin here until alarm() */
		if (ret == 0)
			xerror("unexpected EOF after %zd of %zu bytes",
			       tot, expect_len);

		tot += ret;

		/* keep expect_len - tot from wrapping around */
		if ((size_t)tot > expect_len)
			xerror("received %zd bytes, but only %zu expected",
			       tot, expect_len);

		remaining = expect_len - (size_t)tot;

		get_tcp_inq(&msg, &tcp_inq);

		if (tcp_inq > remaining)
			xerror("inq %u, remaining %zu total_len %zu",
			       tcp_inq, remaining, expect_len);
	} while ((size_t)tot < expect_len);

	ret = write(unixfd, "shut", 4);
	assert(ret == 4);

	/* wait for hangup. Should have received one more byte of data. */
	ret = read(unixfd, tmp, sizeof(tmp));
	assert(ret == 6);
	assert(strncmp(tmp, "closed", 6) == 0);

	sleep(1);

	ret = xrecvmsg_inq(fd, &msg, 1, sizeof(msg_buf));
	assert(ret == 1);

	get_tcp_inq(&msg, &tcp_inq);

	/* tcp_inq should be 1 due to received fin. */
	assert(tcp_inq == 1);

	ret = xrecvmsg_inq(fd, &msg, 1, sizeof(msg_buf));

	/* expect EOF */
	assert(ret == 0);
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 1);

	close(fd);
}
468 
/* accept() one connection on listening socket @s; terminate on failure. */
static int xaccept(int s)
{
	int conn = accept(s, NULL, NULL);

	if (conn < 0)
		die_perror("accept");

	return conn;
}
478 
server(int unixfd)479 static int server(int unixfd)
480 {
481 	int fd = -1, r, on = 1;
482 
483 	switch (pf) {
484 	case AF_INET:
485 		fd = sock_listen_mptcp("127.0.0.1", "15432");
486 		break;
487 	case AF_INET6:
488 		fd = sock_listen_mptcp("::1", "15432");
489 		break;
490 	default:
491 		xerror("Unknown pf %d\n", pf);
492 		break;
493 	}
494 
495 	r = write(unixfd, "conn", 4);
496 	assert(r == 4);
497 
498 	alarm(15);
499 	r = xaccept(fd);
500 
501 	if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
502 		die_perror("setsockopt");
503 
504 	process_one_client(r, unixfd);
505 
506 	close(fd);
507 	return 0;
508 }
509 
client(int unixfd)510 static int client(int unixfd)
511 {
512 	int fd = -1;
513 
514 	alarm(15);
515 
516 	switch (pf) {
517 	case AF_INET:
518 		fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
519 		break;
520 	case AF_INET6:
521 		fd = sock_connect_mptcp("::1", "15432", proto_tx);
522 		break;
523 	default:
524 		xerror("Unknown pf %d\n", pf);
525 	}
526 
527 	connect_one_server(fd, unixfd);
528 
529 	return 0;
530 }
531 
/*
 * Seed rand() with an unsigned int's worth of bytes from getrandom().
 * If getrandom() fails, report it via perror() and exit with status 1.
 */
static void init_rng(void)
{
	unsigned int seed;

	if (getrandom(&seed, sizeof(seed), 0) != -1) {
		srand(seed);
		return;
	}

	perror("getrandom");
	exit(1);
}
543 
/*
 * fork() wrapper.
 *
 * Terminates the process if fork() fails.  In the child it reseeds rand()
 * via init_rng(), so each child draws its own random sequence.  Returns
 * fork()'s result (0 in the child, the child's pid in the parent).
 */
static pid_t xfork(void)
{
	pid_t pid = fork();

	if (pid == 0)
		init_rng();
	else if (pid < 0)
		die_perror("fork");

	return pid;
}
555 
/*
 * Interpret waitpid() status @wstatus of the child named @what.
 *
 * Returns one of:
 *  - 0 if the child exited successfully;
 *  - its exit status (after printing it on stderr) if it exited with a
 *    non-zero status;
 *  - 111 for any other status.
 *
 * A child killed or stopped by a signal terminates this process via
 * xerror().
 */
static int rcheck(int wstatus, const char *what)
{
	if (WIFEXITED(wstatus)) {
		int code = WEXITSTATUS(wstatus);

		if (code != 0)
			fprintf(stderr, "%s exited, status=%d\n", what, code);
		return code;
	}

	if (WIFSIGNALED(wstatus))
		xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
	if (WIFSTOPPED(wstatus))
		xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));

	return 111;
}
571 
/*
 * Entry point of the test.
 *
 * Parses options and creates an AF_UNIX datagram socketpair as the control
 * channel.  Forks the server, waits for its "conn" message, then forks the
 * client.  Finally it waits for both children (server first).
 *
 * Returns the server's non-zero rcheck() result if there is one, otherwise
 * the client's.
 */
int main(int argc, char *argv[])
{
	int e1, e2, wstatus;
	int unixfds[2];
	char ready[4];
	pid_t s, c;
	ssize_t n;

	parse_opts(argc, argv);

	/* report the call that actually failed (this is not a pipe) */
	if (socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds) < 0)
		die_perror("socketpair");

	s = xfork();
	if (s == 0) {
		int ret;	/* server()'s status is an int, not a pid_t */

		close(unixfds[0]);
		ret = server(unixfds[1]);
		close(unixfds[1]);
		return ret;
	}

	close(unixfds[1]);

	/* wait until the server is listening; it announces this with "conn" */
	n = read(unixfds[0], ready, sizeof(ready));
	assert(n == (ssize_t)sizeof(ready));
	assert(memcmp(ready, "conn", sizeof(ready)) == 0);

	c = xfork();
	if (c == 0)
		return client(unixfds[0]);

	close(unixfds[0]);

	if (waitpid(s, &wstatus, 0) == -1)
		die_perror("waitpid");
	e1 = rcheck(wstatus, "server");

	if (waitpid(c, &wstatus, 0) == -1)
		die_perror("waitpid");
	e2 = rcheck(wstatus, "client");

	return e1 ? e1 : e2;
}
615