xref: /linux/tools/testing/selftests/net/mptcp/mptcp_inq.c (revision b4ada0618eed0fbd1b1630f73deb048c592b06a1)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #define _GNU_SOURCE
4 
5 #include <assert.h>
6 #include <errno.h>
7 #include <fcntl.h>
8 #include <limits.h>
9 #include <string.h>
10 #include <stdarg.h>
11 #include <stdbool.h>
12 #include <stdint.h>
13 #include <inttypes.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <strings.h>
17 #include <unistd.h>
18 #include <time.h>
19 
20 #include <sys/ioctl.h>
21 #include <sys/random.h>
22 #include <sys/socket.h>
23 #include <sys/types.h>
24 #include <sys/wait.h>
25 
26 #include <netdb.h>
27 #include <netinet/in.h>
28 
29 #include <linux/tcp.h>
30 #include <linux/sockios.h>
31 
/* Fallback protocol number, used only when the included headers lack it. */
#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif
/* Fallback socket option level, used only when the included headers lack it. */
#ifndef SOL_MPTCP
#define SOL_MPTCP 284
#endif

/* Address family for all test sockets; parse_opts() switches it to AF_INET6 on -6. */
static int pf = AF_INET;
/* Protocol handed to socket() by the connecting side; selected with -t. */
static int proto_tx = IPPROTO_MPTCP;
/* Protocol handed to socket() by the listening side; selected with -r. */
static int proto_rx = IPPROTO_MPTCP;
42 
/*
 * Fatal-error helper for failed system calls.
 *
 * Prints @msg followed by the text for the current errno via perror(),
 * then terminates the process with exit status 1. Does not return.
 */
static void die_perror(const char *msg)
{
	perror(msg);

	exit(1);
}
48 
/*
 * Print the command line synopsis on stderr and exit with status @r.
 *
 * The synopsis lists every option accepted by parse_opts(), including -h.
 * Does not return.
 */
static void die_usage(int r)
{
	fprintf(stderr, "Usage: mptcp_inq [-h] [-6] [-t tcp|mptcp] [-r tcp|mptcp]\n");
	exit(r);
}
54 
/*
 * The attributes let the compiler type-check every caller's arguments
 * against its format string and tell it that control never comes back.
 */
static void xerror(const char *fmt, ...)
	__attribute__((noreturn, format(printf, 1, 2)));

/*
 * Print a printf-style fatal error message on stderr and exit.
 *
 * @fmt: printf format string, followed by its arguments.
 *
 * A newline is appended to the formatted message, then the process
 * terminates with exit status 1. Does not return.
 */
static void xerror(const char *fmt, ...)
{
	va_list ap;

	va_start(ap, fmt);
	vfprintf(stderr, fmt, ap);
	va_end(ap);
	fputc('\n', stderr);
	exit(1);
}
65 
/*
 * Map an error code returned by getaddrinfo() to a printable message.
 *
 * EAI_SYSTEM is resolved through the current errno with strerror();
 * every other code is passed to gai_strerror().
 */
static const char *getxinfo_strerr(int err)
{
	return err == EAI_SYSTEM ? strerror(errno) : gai_strerror(err);
}
73 
/*
 * getaddrinfo() wrapper that only returns on success.
 *
 * @node, @service: passed to getaddrinfo() unchanged.
 * @hints: lookup hints; ai_protocol is overwritten with IPPROTO_TCP if the
 *         first lookup is rejected with EAI_SOCKTYPE.
 * @res: receives the result list produced by getaddrinfo().
 *
 * On any remaining error a "Fatal: getaddrinfo(...)" line is written to
 * stderr and the process exits with status 1.
 */
static void xgetaddrinfo(const char *node, const char *service,
			 struct addrinfo *hints,
			 struct addrinfo **res)
{
	int err;

	err = getaddrinfo(node, service, hints, res);

	/* Fall back to plain TCP exactly once (presumably needed when the C
	 * library does not know the requested protocol). Retrying without
	 * a bound would spin forever if EAI_SOCKTYPE keeps coming back.
	 */
	if (err == EAI_SOCKTYPE && hints && hints->ai_protocol != IPPROTO_TCP) {
		hints->ai_protocol = IPPROTO_TCP;
		err = getaddrinfo(node, service, hints, res);
	}

	if (err) {
		const char *errstr = getxinfo_strerr(err);

		fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
			node ? node : "", service ? service : "", errstr);
		exit(1);
	}
}
97 
/*
 * Create a listening stream socket bound to @listenaddr:@port.
 *
 * @listenaddr: numeric address to bind to (lookup uses AI_NUMERICHOST).
 * @port: service/port string.
 *
 * The address family comes from the global pf and the socket protocol from
 * the global proto_rx. Every address returned by the lookup is tried until
 * one can be bound. Returns the listening descriptor; exits the process if
 * no address could be bound or if listen() fails.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_socktype = SOCK_STREAM,
		.ai_protocol = IPPROTO_MPTCP,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST,
	};
	struct addrinfo *results, *cur;
	int reuse = 1;
	int lfd = -1;

	xgetaddrinfo(listenaddr, port, &hints, &results);

	for (cur = results; cur != NULL; cur = cur->ai_next) {
		lfd = socket(cur->ai_family, cur->ai_socktype, proto_rx);
		if (lfd < 0)
			continue;

		/* SO_REUSEADDR is best effort: complain, but carry on */
		if (setsockopt(lfd, SOL_SOCKET, SO_REUSEADDR, &reuse,
			       sizeof(reuse)) == -1)
			perror("setsockopt");

		if (bind(lfd, cur->ai_addr, cur->ai_addrlen) != 0) {
			/* this candidate is unusable, move on to the next */
			perror("bind");
			close(lfd);
			lfd = -1;
			continue;
		}

		break; /* bound successfully */
	}

	freeaddrinfo(results);

	if (lfd < 0)
		xerror("could not create listen socket");

	if (listen(lfd, 20) != 0)
		die_perror("listen");

	return lfd;
}
143 
/*
 * Open a stream socket with protocol @proto and connect it to
 * @remoteaddr:@port.
 *
 * The address family comes from the global pf. Candidates for which
 * socket() fails are skipped; a failing connect() is fatal. Returns the
 * connected descriptor; exits the process if no socket could be created.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_socktype = SOCK_STREAM,
		.ai_protocol = IPPROTO_MPTCP,
	};
	struct addrinfo *results, *cur;
	int cfd = -1;

	hints.ai_family = pf;
	xgetaddrinfo(remoteaddr, port, &hints, &results);

	for (cur = results; cur != NULL; cur = cur->ai_next) {
		cfd = socket(cur->ai_family, cur->ai_socktype, proto);
		if (cfd >= 0) {
			/* only the first socket we manage to create is tried */
			if (connect(cfd, cur->ai_addr, cur->ai_addrlen) != 0)
				die_perror("connect");
			break;
		}
	}

	if (cfd < 0)
		xerror("could not create connect socket");

	freeaddrinfo(results);
	return cfd;
}
174 
/*
 * Convert a protocol name given on the command line to its IPPROTO_* value.
 *
 * Accepts "tcp" and "mptcp" in any letter case. Anything else prints the
 * usage text and exits with status 1 through die_usage().
 */
static int protostr_to_num(const char *s)
{
	if (!strcasecmp(s, "mptcp"))
		return IPPROTO_MPTCP;

	if (!strcasecmp(s, "tcp"))
		return IPPROTO_TCP;

	die_usage(1);
	return 0;
}
185 
186 static void parse_opts(int argc, char **argv)
187 {
188 	int c;
189 
190 	while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
191 		switch (c) {
192 		case 'h':
193 			die_usage(0);
194 			break;
195 		case '6':
196 			pf = AF_INET6;
197 			break;
198 		case 't':
199 			proto_tx = protostr_to_num(optarg);
200 			break;
201 		case 'r':
202 			proto_rx = protostr_to_num(optarg);
203 			break;
204 		default:
205 			die_usage(1);
206 			break;
207 		}
208 	}
209 }
210 
/*
 * Poll the transmit queue of @fd until it is reported empty.
 *
 * @fd: socket to inspect.
 * @timeout: maximum number of polls; each unsuccessful poll is followed by
 *           a 1ms sleep, so this is roughly a budget in milliseconds.
 * @total: upper bound for the number of bytes that may still be queued.
 *
 * Exits the process if an ioctl fails, if TIOCOUTQ reports more than
 * @total bytes, or if the queue is still non-empty after @timeout polls.
 */
static void wait_for_ack(int fd, int timeout, size_t total)
{
	int i;

	for (i = 0; i < timeout; i++) {
		int nsd = -1, ret, queued = -1;
		struct timespec req;

		ret = ioctl(fd, TIOCOUTQ, &queued);
		if (ret < 0)
			die_perror("TIOCOUTQ");

		ret = ioctl(fd, SIOCOUTQNSD, &nsd);
		if (ret < 0)
			die_perror("SIOCOUTQNSD");

		/* queued is an int: print it with %d, a negative value is
		 * caught here as well because the cast makes it huge
		 */
		if ((size_t)queued > total)
			xerror("TIOCOUTQ %d, but only %zu expected\n", queued, total);
		assert(nsd <= queued);

		if (queued == 0)
			return;

		/* wait for peer to ack rx of all data */
		req.tv_sec = 0;
		req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
		nanosleep(&req, NULL);
	}

	xerror("still tx data queued after %d ms\n", timeout);
}
243 
/*
 * Client side of the test: send data to the server over @fd, taking cues
 * from it over @unixfd.
 *
 * @fd: connected socket under test.
 * @unixfd: control channel shared with the server process.
 *
 * Sequence:
 *  1. wait for "xmit", announce the size of a small random payload on
 *     @unixfd and send that payload in a single write();
 *  2. wait for "huge", announce a random bulk size (1MB up to 17MB), wait
 *     until the small payload has left the transmit queue, then send the
 *     bulk data;
 *  3. wait for "shut", wait until the transmit queue is empty, send one
 *     last byte, close @fd and report "closed" on @unixfd.
 *
 * Every step is checked; a mismatch aborts or exits the process.
 * Closes both @fd and @unixfd.
 */
static void connect_one_server(int fd, int unixfd)
{
	size_t len, i, total, sent;
	char buf[4096], buf2[4096];
	ssize_t ret;

	len = rand() % (sizeof(buf) - 1);

	if (len < 128)
		len = 128;

	for (i = 0; i < len ; i++) {
		buf[i] = rand() % 26;
		buf[i] += 'A';
	}

	buf[i] = '\n';

	/* un-block server */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);

	assert(strncmp(buf2, "xmit", 4) == 0);

	ret = write(unixfd, &len, sizeof(len));
	assert(ret == (ssize_t)sizeof(len));

	ret = write(fd, buf, len);
	if (ret < 0)
		die_perror("write");

	if (ret != (ssize_t)len)
		xerror("short write");

	/* verify the read itself before looking at the buffer, as is done
	 * for the other handshake messages
	 */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "huge", 4) == 0);

	total = rand() % (16 * 1024 * 1024);
	total += (1 * 1024 * 1024);
	sent = total;

	ret = write(unixfd, &total, sizeof(total));
	assert(ret == (ssize_t)sizeof(total));

	wait_for_ack(fd, 5000, len);

	while (total > 0) {
		if (total > sizeof(buf))
			len = sizeof(buf);
		else
			len = total;

		ret = write(fd, buf, len);
		if (ret < 0)
			die_perror("write");
		/* a short write simply leaves more for the next round */
		total -= ret;

		/* we don't have to care about buf content, only
		 * number of total bytes sent
		 */
	}

	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "shut", 4) == 0);

	wait_for_ack(fd, 5000, sent);

	/* one trailing byte, followed by close() */
	ret = write(fd, buf, 1);
	assert(ret == 1);
	close(fd);
	ret = write(unixfd, "closed", 6);
	assert(ret == 6);

	close(unixfd);
}
320 
/*
 * Extract the TCP_CM_INQ value from the control data of a received message.
 *
 * @msgh: message header filled in by recvmsg().
 * @inqv: receives the payload of the first control message whose level is
 *        IPPROTO_TCP and whose type is TCP_CM_INQ.
 *
 * Exits the process through xerror() when no such control message exists.
 */
static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
{
	struct cmsghdr *hdr;

	for (hdr = CMSG_FIRSTHDR(msgh); hdr != NULL; hdr = CMSG_NXTHDR(msgh, hdr)) {
		if (hdr->cmsg_level != IPPROTO_TCP)
			continue;
		if (hdr->cmsg_type != TCP_CM_INQ)
			continue;

		/* copy out rather than dereference the cmsg payload */
		memcpy(inqv, CMSG_DATA(hdr), sizeof(*inqv));
		return;
	}

	xerror("could not find TCP_CM_INQ cmsg type");
}
334 
/*
 * Server side of the test: check the FIONREAD and TCP_CM_INQ values
 * reported for @fd while steering the client over @unixfd.
 *
 * @fd: accepted connection; the caller has enabled TCP_INQ on it.
 * @unixfd: control channel shared with the client process.
 *
 * Sequence:
 *  1. send "xmit", read the announced payload size and poll FIONREAD
 *     until exactly that many bytes are queued;
 *  2. read one byte and expect an inq value of size - 1, then read the
 *     rest and expect an inq value of 0;
 *  3. send "huge", read the announced bulk size and drain the bulk data,
 *     checking after each read that inq never exceeds what is still due;
 *  4. send "shut", wait for "closed", read the final byte and then EOF,
 *     expecting an inq value of 1 both times.
 *
 * Every step is checked; a mismatch aborts or exits the process.
 * Closes @fd.
 */
static void process_one_client(int fd, int unixfd)
{
	unsigned int tcp_inq;
	size_t expect_len;
	char msg_buf[4096];
	char buf[4096];
	char tmp[16];
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = 1,
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = msg_buf,
		.msg_controllen = sizeof(msg_buf),
	};
	ssize_t ret, tot;

	ret = write(unixfd, "xmit", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	if (expect_len > sizeof(buf))
		xerror("expect len %zu exceeds buffer size", expect_len);

	for (;;) {
		struct timespec req;
		unsigned int queued;

		ret = ioctl(fd, FIONREAD, &queued);
		if (ret < 0)
			die_perror("FIONREAD");
		if (queued > expect_len)
			xerror("FIONREAD returned %u, but only %zu expected\n",
			       queued, expect_len);
		if (queued == expect_len)
			break;

		/* not everything has arrived yet, look again in 1ms */
		req.tv_sec = 0;
		req.tv_nsec = 1000 * 1000ul;
		nanosleep(&req, NULL);
	}

	/* read one byte, expect cmsg to return expected - 1 */
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	if (msg.msg_controllen == 0)
		xerror("msg_controllen is 0");

	get_tcp_inq(&msg, &tcp_inq);

	assert((size_t)tcp_inq == (expect_len - 1));

	/* recvmsg() overwrites msg_controllen with the length it used, so
	 * the full control buffer is offered again before every call
	 */
	iov.iov_len = sizeof(buf);
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* should have gotten exact remainder of all pending data */
	assert(ret == (ssize_t)tcp_inq);

	/* should be 0, all drained */
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 0);

	/* request a large swath of data. */
	ret = write(unixfd, "huge", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	/* peer should send us a few mb of data */
	if (expect_len <= sizeof(buf))
		xerror("expect len %zu too small\n", expect_len);

	tot = 0;
	do {
		iov.iov_len = sizeof(buf);
		msg.msg_controllen = sizeof(msg_buf);
		ret = recvmsg(fd, &msg, 0);
		if (ret < 0)
			die_perror("recvmsg");

		tot += ret;

		get_tcp_inq(&msg, &tcp_inq);

		/* tcp_inq is unsigned and the remainder is a size_t, so
		 * neither may be printed with %d
		 */
		if (tcp_inq > expect_len - tot)
			xerror("inq %u, remaining %zu total_len %zu\n",
			       tcp_inq, expect_len - (size_t)tot, expect_len);

		assert(tcp_inq <= expect_len - tot);
	} while ((size_t)tot < expect_len);

	ret = write(unixfd, "shut", 4);
	assert(ret == 4);

	/* wait for hangup. Should have received one more byte of data. */
	ret = read(unixfd, tmp, sizeof(tmp));
	assert(ret == 6);
	assert(strncmp(tmp, "closed", 6) == 0);

	sleep(1);

	iov.iov_len = 1;
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");
	assert(ret == 1);

	get_tcp_inq(&msg, &tcp_inq);

	/* tcp_inq should be 1 due to received fin. */
	assert(tcp_inq == 1);

	iov.iov_len = 1;
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* expect EOF */
	assert(ret == 0);
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 1);

	close(fd);
}
467 
/*
 * accept() wrapper that only returns on success.
 *
 * @s: listening socket. The peer address is not requested.
 *
 * Returns the descriptor of the accepted connection; a failing accept()
 * terminates the process through die_perror().
 */
static int xaccept(int s)
{
	int conn = accept(s, NULL, NULL);

	if (conn >= 0)
		return conn;

	die_perror("accept");
	return conn;
}
477 
478 static int server(int unixfd)
479 {
480 	int fd = -1, r, on = 1;
481 
482 	switch (pf) {
483 	case AF_INET:
484 		fd = sock_listen_mptcp("127.0.0.1", "15432");
485 		break;
486 	case AF_INET6:
487 		fd = sock_listen_mptcp("::1", "15432");
488 		break;
489 	default:
490 		xerror("Unknown pf %d\n", pf);
491 		break;
492 	}
493 
494 	r = write(unixfd, "conn", 4);
495 	assert(r == 4);
496 
497 	alarm(15);
498 	r = xaccept(fd);
499 
500 	if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
501 		die_perror("setsockopt");
502 
503 	process_one_client(r, unixfd);
504 
505 	return 0;
506 }
507 
508 static int client(int unixfd)
509 {
510 	int fd = -1;
511 
512 	alarm(15);
513 
514 	switch (pf) {
515 	case AF_INET:
516 		fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
517 		break;
518 	case AF_INET6:
519 		fd = sock_connect_mptcp("::1", "15432", proto_tx);
520 		break;
521 	default:
522 		xerror("Unknown pf %d\n", pf);
523 	}
524 
525 	connect_one_server(fd, unixfd);
526 
527 	return 0;
528 }
529 
/*
 * Seed rand() with a value obtained from getrandom().
 *
 * Prints an error and exits with status 1 if getrandom() fails.
 */
static void init_rng(void)
{
	unsigned int seed;
	ssize_t got = getrandom(&seed, sizeof(seed), 0);

	if (got == -1) {
		perror("getrandom");
		exit(1);
	}

	srand(seed);
}
541 
/*
 * fork() wrapper that only returns on success.
 *
 * Returns 0 in the child and the child's pid in the parent. The child
 * re-seeds rand() through init_rng() before returning. A failing fork()
 * terminates the process through die_perror().
 */
static pid_t xfork(void)
{
	pid_t child = fork();

	if (child == 0)
		init_rng();
	else if (child < 0)
		die_perror("fork");

	return child;
}
553 
/*
 * Turn a waitpid() status into an exit code for main().
 *
 * @wstatus: status word filled in by waitpid().
 * @what: name of the child, used in diagnostics.
 *
 * Returns the child's exit status if it exited normally (a non-zero
 * status is also reported on stderr). A child that was killed or stopped
 * by a signal terminates this process through xerror(). Any other status
 * yields 111.
 */
static int rcheck(int wstatus, const char *what)
{
	int code;

	if (WIFEXITED(wstatus)) {
		code = WEXITSTATUS(wstatus);
		if (code != 0)
			fprintf(stderr, "%s exited, status=%d\n", what, code);
		return code;
	}

	if (WIFSIGNALED(wstatus))
		xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
	else if (WIFSTOPPED(wstatus))
		xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));

	return 111;
}
569 
/*
 * Test driver: fork a server and a client that talk to each other over a
 * socket under test, coordinated through an AF_UNIX datagram socketpair.
 *
 * The server child gets unixfds[1], the client child unixfds[0]. The
 * client is only started after the server has sent its 4 byte ready
 * message. Returns the server's result if it is non-zero, otherwise the
 * client's result.
 */
int main(int argc, char *argv[])
{
	int e1, e2, wstatus;
	pid_t s, c, ret;
	int unixfds[2];
	char ready[4];
	ssize_t n;

	parse_opts(argc, argv);

	if (socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds) < 0)
		die_perror("socketpair");

	s = xfork();
	if (s == 0)
		return server(unixfds[1]);

	close(unixfds[1]);

	/* wait until server bound a socket; the message lands in its own
	 * buffer instead of the variable that also holds read()'s result
	 */
	n = read(unixfds[0], ready, sizeof(ready));
	assert(n == (ssize_t)sizeof(ready));

	c = xfork();
	if (c == 0)
		return client(unixfds[0]);

	close(unixfds[0]);

	ret = waitpid(s, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e1 = rcheck(wstatus, "server");
	ret = waitpid(c, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e2 = rcheck(wstatus, "client");

	return e1 ? e1 : e2;
}
609