xref: /linux/tools/testing/selftests/net/mptcp/mptcp_connect.c (revision 41fb0cf1bced59c1fe178cf6cc9f716b5da9e40e)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #define _GNU_SOURCE
4 
5 #include <errno.h>
6 #include <limits.h>
7 #include <fcntl.h>
8 #include <string.h>
9 #include <stdarg.h>
10 #include <stdbool.h>
11 #include <stdint.h>
12 #include <stdio.h>
13 #include <stdlib.h>
14 #include <strings.h>
15 #include <signal.h>
16 #include <unistd.h>
17 #include <time.h>
18 
19 #include <sys/poll.h>
20 #include <sys/sendfile.h>
21 #include <sys/stat.h>
22 #include <sys/socket.h>
23 #include <sys/types.h>
24 #include <sys/mman.h>
25 
26 #include <netdb.h>
27 #include <netinet/in.h>
28 
29 #include <linux/tcp.h>
30 #include <linux/time_types.h>
31 
32 extern int optind;
33 
34 #ifndef IPPROTO_MPTCP
35 #define IPPROTO_MPTCP 262
36 #endif
37 #ifndef TCP_ULP
38 #define TCP_ULP 31
39 #endif
40 
/* Timeout, in milliseconds, passed to every poll() call (set by -t). */
static int  poll_timeout = 10 * 1000;
/* True when -l was given: accept a connection instead of initiating one. */
static bool listen_mode;
/*
 * Set by the SIGUSR1 handler.  The object is written in signal context,
 * so it has to be a volatile sig_atomic_t rather than a plain bool.
 */
static volatile sig_atomic_t quit;
44 
/* Transfer strategy selected with -m. */
enum cfg_mode {
	CFG_MODE_POLL = 0,	/* interleaved read/write driven by poll() */
	CFG_MODE_MMAP,		/* mmap() the input file, then write() it */
	CFG_MODE_SENDFILE,	/* hand the input file to sendfile() */
};
50 
/* MSG_PEEK behaviour selected with -P. */
enum cfg_peek {
	CFG_NONE_PEEK = 0,	/* never peek */
	CFG_WITH_PEEK,		/* keep the peeked bytes, drain the socket separately */
	CFG_AFTER_PEEK,		/* peek first, then read the data for real */
};
56 
57 static enum cfg_mode cfg_mode = CFG_MODE_POLL;
58 static enum cfg_peek cfg_peek = CFG_NONE_PEEK;
59 static const char *cfg_host;
60 static const char *cfg_port	= "12000";
61 static int cfg_sock_proto	= IPPROTO_MPTCP;
62 static bool tcpulp_audit;
63 static int pf = AF_INET;
64 static int cfg_sndbuf;
65 static int cfg_rcvbuf;
66 static bool cfg_join;
67 static bool cfg_remove;
68 static unsigned int cfg_time;
69 static unsigned int cfg_do_w;
70 static int cfg_wait;
71 static uint32_t cfg_mark;
72 
/* Control message kinds requested with -c. */
struct cfg_cmsg_types {
	unsigned int cmsg_enabled : 1;	/* set as soon as -c is parsed */
	unsigned int timestampns : 1;	/* "TIMESTAMPNS" was requested */
	unsigned int tcp_inq : 1;	/* "TCPINQ" was requested */
};
78 
/* Extra socket options requested with -o. */
struct cfg_sockopt_types {
	unsigned int transparent : 1;	/* "TRANSPARENT" was requested */
};
82 
/* Bookkeeping for the TCP_INQ control message checks. */
struct tcp_inq_state {
	unsigned int last;	/* most recent TCP_CM_INQ value seen */
	bool expect_eof;	/* the next receive must report EOF */
};
87 
88 static struct tcp_inq_state tcp_inq;
89 
90 static struct cfg_cmsg_types cfg_cmsg_types;
91 static struct cfg_sockopt_types cfg_sockopt_types;
92 
/*
 * Print the command line synopsis plus a description of the options on
 * stderr and terminate the process with exit status 1.
 */
static void die_usage(void)
{
	/* note the trailing blank: the two literals are concatenated */
	fprintf(stderr, "Usage: mptcp_connect [-6] [-u] [-s MPTCP|TCP] [-p port] [-m mode] "
		"[-l] [-w sec] [-t num] [-T num] connect_address\n");
	fprintf(stderr, "\t-6 use ipv6\n");
	fprintf(stderr, "\t-t num -- set poll timeout to num\n");
	fprintf(stderr, "\t-T num -- set expected runtime to num ms\n");
	fprintf(stderr, "\t-S num -- set SO_SNDBUF to num\n");
	fprintf(stderr, "\t-R num -- set SO_RCVBUF to num\n");
	fprintf(stderr, "\t-p num -- use port num\n");
	fprintf(stderr, "\t-s [MPTCP|TCP] -- use mptcp(default) or tcp sockets\n");
	fprintf(stderr, "\t-m [poll|mmap|sendfile] -- use poll(default)/mmap+write/sendfile\n");
	fprintf(stderr, "\t-M mark -- set socket packet mark\n");
	fprintf(stderr, "\t-u -- check mptcp ulp\n");
	fprintf(stderr, "\t-w num -- wait num sec before closing the socket\n");
	fprintf(stderr, "\t-c cmsg -- test cmsg type <cmsg>\n");
	fprintf(stderr, "\t-o option -- test sockopt <option>\n");
	fprintf(stderr,
		"\t-P [saveWithPeek|saveAfterPeek] -- save data with/after MSG_PEEK from tcp socket\n");
	exit(1);
}
114 
/*
 * Report a fatal error: format the message printf-style onto stderr and
 * terminate the process with exit status 1.  Never returns.
 */
static void xerror(const char *fmt, ...)
{
	va_list args;

	va_start(args, fmt);
	(void)vfprintf(stderr, fmt, args);
	va_end(args);

	exit(1);
}
124 
125 static void handle_signal(int nr)
126 {
127 	quit = true;
128 }
129 
/*
 * Translate a getaddrinfo()/getnameinfo() error code into text.
 * EAI_SYSTEM means the real cause is in errno, so strerror() is used for
 * it; every other code goes through gai_strerror().
 */
static const char *getxinfo_strerr(int err)
{
	return (err == EAI_SYSTEM) ? strerror(errno) : gai_strerror(err);
}
137 
/*
 * getnameinfo() wrapper that always asks for numeric host and service
 * strings and terminates the program (exit status 1) when the lookup
 * fails.
 */
static void xgetnameinfo(const struct sockaddr *addr, socklen_t addrlen,
			 char *host, socklen_t hostlen,
			 char *serv, socklen_t servlen)
{
	int err;

	err = getnameinfo(addr, addrlen, host, hostlen, serv, servlen,
			  NI_NUMERICHOST | NI_NUMERICSERV);
	if (!err)
		return;

	fprintf(stderr, "Fatal: getnameinfo: %s\n", getxinfo_strerr(err));
	exit(1);
}
153 
/*
 * getaddrinfo() wrapper that terminates the program (exit status 1) when
 * the lookup fails.  On return *res holds the result list, which the
 * caller releases with freeaddrinfo().
 */
static void xgetaddrinfo(const char *node, const char *service,
			 const struct addrinfo *hints,
			 struct addrinfo **res)
{
	const char *shown_node = node ? node : "";
	const char *shown_service = service ? service : "";
	int err;

	err = getaddrinfo(node, service, hints, res);
	if (!err)
		return;

	fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
		shown_node, shown_service, getxinfo_strerr(err));
	exit(1);
}
168 
/* Set SO_RCVBUF on fd to size; any failure is fatal (exit status 1). */
static void set_rcvbuf(int fd, unsigned int size)
{
	if (setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &size, sizeof(size)) == 0)
		return;

	perror("set SO_RCVBUF");
	exit(1);
}
179 
/* Set SO_SNDBUF on fd to size; any failure is fatal (exit status 1). */
static void set_sndbuf(int fd, unsigned int size)
{
	if (setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &size, sizeof(size)) == 0)
		return;

	perror("set SO_SNDBUF");
	exit(1);
}
190 
/* Set SO_MARK on fd to mark; any failure is fatal (exit status 1). */
static void set_mark(int fd, uint32_t mark)
{
	if (setsockopt(fd, SOL_SOCKET, SO_MARK, &mark, sizeof(mark)) == 0)
		return;

	perror("set SO_MARK");
	exit(1);
}
201 
/*
 * Enable transparent mode on fd using the option that matches the
 * address family pf.  A failure is only reported with perror(); other
 * families are silently ignored.
 */
static void set_transparent(int fd, int pf)
{
	int enable = 1;

	if (pf == AF_INET) {
		if (setsockopt(fd, SOL_IP, IP_TRANSPARENT, &enable, sizeof(enable)) == -1)
			perror("IP_TRANSPARENT");
	} else if (pf == AF_INET6) {
		if (setsockopt(fd, IPPROTO_IPV6, IPV6_TRANSPARENT, &enable, sizeof(enable)) == -1)
			perror("IPV6_TRANSPARENT");
	}
}
217 
/*
 * Create a listening socket bound to listenaddr:port.
 *
 * The address is resolved numerically for the configured family (pf).
 * Every result is tried in turn with protocol cfg_sock_proto until one
 * can be bound.  SO_REUSEADDR is always requested and transparent mode is
 * enabled when it was asked for with -o.
 *
 * Returns the listening descriptor, or a negative value on failure.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
	};
	struct addrinfo *a, *addr;
	/* start out invalid so an empty result list is reported as a failure */
	int sock = -1;
	int one = 1;

	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, cfg_sock_proto);
		if (sock < 0)
			continue;

		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
				     sizeof(one)))
			perror("setsockopt");

		if (cfg_sockopt_types.transparent)
			set_transparent(sock, pf);

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0) {
		fprintf(stderr, "Could not create listen socket\n");
		return sock;
	}

	if (listen(sock, 20)) {
		perror("listen");
		close(sock);
		return -1;
	}

	return sock;
}
271 
/*
 * Check that a plain TCP socket refuses the "mptcp" upper layer protocol.
 *
 * For each IPv4 address remoteaddr:port resolves to, a TCP socket is
 * created and setsockopt(TCP_ULP, "mptcp") is attempted.  The check
 * passes as soon as one attempt fails with EOPNOTSUPP.
 *
 * Returns true when the check passed, false otherwise.
 */
static bool sock_test_tcpulp(const char * const remoteaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_family = AF_INET,
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *a, *addr;
	bool test_pass = false;
	int sock, ret;

	xgetaddrinfo(remoteaddr, port, &hints, &addr);
	for (a = addr; a; a = a->ai_next) {
		int saved_errno;

		sock = socket(a->ai_family, a->ai_socktype, IPPROTO_TCP);
		if (sock < 0) {
			perror("socket");
			continue;
		}

		ret = setsockopt(sock, IPPROTO_TCP, TCP_ULP, "mptcp",
				 sizeof("mptcp"));
		/* close() may overwrite errno, keep the setsockopt() result */
		saved_errno = errno;
		close(sock);

		if (ret == -1 && saved_errno == EOPNOTSUPP) {
			test_pass = true;
			break;
		}

		if (!ret) {
			fprintf(stderr,
				"setsockopt(TCP_ULP) returned 0\n");
		} else {
			errno = saved_errno;
			perror("setsockopt(TCP_ULP)");
		}
	}

	/* the result list was allocated by getaddrinfo() */
	freeaddrinfo(addr);

	return test_pass;
}
308 
/*
 * Connect to remoteaddr:port with a socket of protocol proto.
 *
 * Each address returned for the configured family (pf) is tried in order;
 * the packet mark is applied before connecting when one was configured.
 *
 * Returns the connected descriptor, or a negative value when no address
 * could be reached.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *candidates;
	struct addrinfo *cur;
	int fd = -1;

	hints.ai_family = pf;
	xgetaddrinfo(remoteaddr, port, &hints, &candidates);

	for (cur = candidates; cur; cur = cur->ai_next) {
		fd = socket(cur->ai_family, cur->ai_socktype, proto);
		if (fd < 0) {
			perror("socket");
			continue;
		}

		if (cfg_mark)
			set_mark(fd, cfg_mark);

		if (connect(fd, cur->ai_addr, cur->ai_addrlen) != 0) {
			/* this address did not work, try the next one */
			perror("connect()");
			close(fd);
			fd = -1;
			continue;
		}

		break;
	}

	freeaddrinfo(candidates);
	return fd;
}
343 
344 static size_t do_rnd_write(const int fd, char *buf, const size_t len)
345 {
346 	static bool first = true;
347 	unsigned int do_w;
348 	ssize_t bw;
349 
350 	do_w = rand() & 0xffff;
351 	if (do_w == 0 || do_w > len)
352 		do_w = len;
353 
354 	if (cfg_join && first && do_w > 100)
355 		do_w = 100;
356 
357 	if (cfg_remove && do_w > cfg_do_w)
358 		do_w = cfg_do_w;
359 
360 	bw = write(fd, buf, do_w);
361 	if (bw < 0)
362 		perror("write");
363 
364 	/* let the join handshake complete, before going on */
365 	if (cfg_join && first) {
366 		usleep(200000);
367 		first = false;
368 	}
369 
370 	if (cfg_remove)
371 		usleep(200000);
372 
373 	return bw;
374 }
375 
/*
 * Write all len bytes of buf to fd, retrying after short writes.
 *
 * Returns len on success and 0 when write() reported an error (the
 * error is printed with perror()).
 */
static size_t do_write(const int fd, char *buf, const size_t len)
{
	size_t done;

	for (done = 0; done < len; ) {
		ssize_t n = write(fd, buf + done, len - done);

		if (n < 0) {
			perror("write");
			return 0;
		}

		done += (size_t)n;
	}

	return done;
}
396 
397 static void process_cmsg(struct msghdr *msgh)
398 {
399 	struct __kernel_timespec ts;
400 	bool inq_found = false;
401 	bool ts_found = false;
402 	unsigned int inq = 0;
403 	struct cmsghdr *cmsg;
404 
405 	for (cmsg = CMSG_FIRSTHDR(msgh); cmsg ; cmsg = CMSG_NXTHDR(msgh, cmsg)) {
406 		if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SO_TIMESTAMPNS_NEW) {
407 			memcpy(&ts, CMSG_DATA(cmsg), sizeof(ts));
408 			ts_found = true;
409 			continue;
410 		}
411 		if (cmsg->cmsg_level == IPPROTO_TCP && cmsg->cmsg_type == TCP_CM_INQ) {
412 			memcpy(&inq, CMSG_DATA(cmsg), sizeof(inq));
413 			inq_found = true;
414 			continue;
415 		}
416 
417 	}
418 
419 	if (cfg_cmsg_types.timestampns) {
420 		if (!ts_found)
421 			xerror("TIMESTAMPNS not present\n");
422 	}
423 
424 	if (cfg_cmsg_types.tcp_inq) {
425 		if (!inq_found)
426 			xerror("TCP_INQ not present\n");
427 
428 		if (inq > 1024)
429 			xerror("tcp_inq %u is larger than one kbyte\n", inq);
430 		tcp_inq.last = inq;
431 	}
432 }
433 
434 static ssize_t do_recvmsg_cmsg(const int fd, char *buf, const size_t len)
435 {
436 	char msg_buf[8192];
437 	struct iovec iov = {
438 		.iov_base = buf,
439 		.iov_len = len,
440 	};
441 	struct msghdr msg = {
442 		.msg_iov = &iov,
443 		.msg_iovlen = 1,
444 		.msg_control = msg_buf,
445 		.msg_controllen = sizeof(msg_buf),
446 	};
447 	int flags = 0;
448 	unsigned int last_hint = tcp_inq.last;
449 	int ret = recvmsg(fd, &msg, flags);
450 
451 	if (ret <= 0) {
452 		if (ret == 0 && tcp_inq.expect_eof)
453 			return ret;
454 
455 		if (ret == 0 && cfg_cmsg_types.tcp_inq)
456 			if (last_hint != 1 && last_hint != 0)
457 				xerror("EOF but last tcp_inq hint was %u\n", last_hint);
458 
459 		return ret;
460 	}
461 
462 	if (tcp_inq.expect_eof)
463 		xerror("expected EOF, last_hint %u, now %u\n",
464 		       last_hint, tcp_inq.last);
465 
466 	if (msg.msg_controllen && !cfg_cmsg_types.cmsg_enabled)
467 		xerror("got %lu bytes of cmsg data, expected 0\n",
468 		       (unsigned long)msg.msg_controllen);
469 
470 	if (msg.msg_controllen == 0 && cfg_cmsg_types.cmsg_enabled)
471 		xerror("%s\n", "got no cmsg data");
472 
473 	if (msg.msg_controllen)
474 		process_cmsg(&msg);
475 
476 	if (cfg_cmsg_types.tcp_inq) {
477 		if ((size_t)ret < len && last_hint > (unsigned int)ret) {
478 			if (ret + 1 != (int)last_hint) {
479 				int next = read(fd, msg_buf, sizeof(msg_buf));
480 
481 				xerror("read %u of %u, last_hint was %u tcp_inq hint now %u next_read returned %d/%m\n",
482 				       ret, (unsigned int)len, last_hint, tcp_inq.last, next);
483 			} else {
484 				tcp_inq.expect_eof = true;
485 			}
486 		}
487 	}
488 
489 	return ret;
490 }
491 
492 static ssize_t do_rnd_read(const int fd, char *buf, const size_t len)
493 {
494 	int ret = 0;
495 	char tmp[16384];
496 	size_t cap = rand();
497 
498 	cap &= 0xffff;
499 
500 	if (cap == 0)
501 		cap = 1;
502 	else if (cap > len)
503 		cap = len;
504 
505 	if (cfg_peek == CFG_WITH_PEEK) {
506 		ret = recv(fd, buf, cap, MSG_PEEK);
507 		ret = (ret < 0) ? ret : read(fd, tmp, ret);
508 	} else if (cfg_peek == CFG_AFTER_PEEK) {
509 		ret = recv(fd, buf, cap, MSG_PEEK);
510 		ret = (ret < 0) ? ret : read(fd, buf, cap);
511 	} else if (cfg_cmsg_types.cmsg_enabled) {
512 		ret = do_recvmsg_cmsg(fd, buf, cap);
513 	} else {
514 		ret = read(fd, buf, cap);
515 	}
516 
517 	return ret;
518 }
519 
/*
 * Add O_NONBLOCK to the file status flags of fd.  Nothing happens when
 * the current flags cannot be read; a failing F_SETFL is not reported.
 */
static void set_nonblock(int fd)
{
	int cur = fcntl(fd, F_GETFL);

	if (cur != -1)
		fcntl(fd, F_SETFL, cur | O_NONBLOCK);
}
529 
530 static int copyfd_io_poll(int infd, int peerfd, int outfd, bool *in_closed_after_out)
531 {
532 	struct pollfd fds = {
533 		.fd = peerfd,
534 		.events = POLLIN | POLLOUT,
535 	};
536 	unsigned int woff = 0, wlen = 0;
537 	char wbuf[8192];
538 
539 	set_nonblock(peerfd);
540 
541 	for (;;) {
542 		char rbuf[8192];
543 		ssize_t len;
544 
545 		if (fds.events == 0)
546 			break;
547 
548 		switch (poll(&fds, 1, poll_timeout)) {
549 		case -1:
550 			if (errno == EINTR)
551 				continue;
552 			perror("poll");
553 			return 1;
554 		case 0:
555 			fprintf(stderr, "%s: poll timed out (events: "
556 				"POLLIN %u, POLLOUT %u)\n", __func__,
557 				fds.events & POLLIN, fds.events & POLLOUT);
558 			return 2;
559 		}
560 
561 		if (fds.revents & POLLIN) {
562 			len = do_rnd_read(peerfd, rbuf, sizeof(rbuf));
563 			if (len == 0) {
564 				/* no more data to receive:
565 				 * peer has closed its write side
566 				 */
567 				fds.events &= ~POLLIN;
568 
569 				if ((fds.events & POLLOUT) == 0) {
570 					*in_closed_after_out = true;
571 					/* and nothing more to send */
572 					break;
573 				}
574 
575 			/* Else, still have data to transmit */
576 			} else if (len < 0) {
577 				perror("read");
578 				return 3;
579 			}
580 
581 			do_write(outfd, rbuf, len);
582 		}
583 
584 		if (fds.revents & POLLOUT) {
585 			if (wlen == 0) {
586 				woff = 0;
587 				wlen = read(infd, wbuf, sizeof(wbuf));
588 			}
589 
590 			if (wlen > 0) {
591 				ssize_t bw;
592 
593 				bw = do_rnd_write(peerfd, wbuf + woff, wlen);
594 				if (bw < 0)
595 					return 111;
596 
597 				woff += bw;
598 				wlen -= bw;
599 			} else if (wlen == 0) {
600 				/* We have no more data to send. */
601 				fds.events &= ~POLLOUT;
602 
603 				if ((fds.events & POLLIN) == 0)
604 					/* ... and peer also closed already */
605 					break;
606 
607 				/* ... but we still receive.
608 				 * Close our write side, ev. give some time
609 				 * for address notification and/or checking
610 				 * the current status
611 				 */
612 				if (cfg_wait)
613 					usleep(cfg_wait);
614 				shutdown(peerfd, SHUT_WR);
615 			} else {
616 				if (errno == EINTR)
617 					continue;
618 				perror("read");
619 				return 4;
620 			}
621 		}
622 
623 		if (fds.revents & (POLLERR | POLLNVAL)) {
624 			fprintf(stderr, "Unexpected revents: "
625 				"POLLERR/POLLNVAL(%x)\n", fds.revents);
626 			return 5;
627 		}
628 	}
629 
630 	/* leave some time for late join/announce */
631 	if (cfg_remove)
632 		usleep(cfg_wait);
633 
634 	close(peerfd);
635 	return 0;
636 }
637 
/*
 * Copy everything received on infd to outfd until EOF.
 *
 * Returns 0 after a clean EOF, a negative value when reading failed and
 * the (positive) size of the last chunk when it could not be written out
 * completely.
 */
static int do_recvfile(int infd, int outfd)
{
	char chunk[16384];
	ssize_t got;

	for (;;) {
		got = do_rnd_read(infd, chunk, sizeof(chunk));
		if (got < 0) {
			perror("read");
			break;
		}
		if (got == 0)
			break;

		/* a short or failed write ends the copy, 'got' stays non-zero */
		if (write(outfd, chunk, got) != got)
			break;
	}

	return (int)got;
}
656 
/*
 * Send the first size bytes of the file behind infd to outfd by mapping
 * the file and writing the mapping out.
 *
 * Returns 0 when everything was written, 1 when the mapping could not be
 * created, otherwise the number of bytes that were left unsent when
 * write() failed.
 */
static int do_mmap(int infd, int outfd, unsigned int size)
{
	size_t rem = size;
	size_t off = 0;
	char *inbuf;

	/*
	 * mmap() rejects zero length mappings with EINVAL.  An empty input
	 * simply has nothing to send, which is not an error.
	 */
	if (size == 0)
		return 0;

	inbuf = mmap(NULL, size, PROT_READ, MAP_SHARED, infd, 0);
	if (inbuf == MAP_FAILED) {
		perror("mmap");
		return 1;
	}

	while (rem > 0) {
		ssize_t ret = write(outfd, inbuf + off, rem);

		if (ret < 0) {
			perror("write");
			break;
		}

		off += ret;
		rem -= ret;
	}

	munmap(inbuf, size);
	return rem;
}
685 
/*
 * Return the size of the regular file behind fd.
 *
 * Returns the size in bytes, or a negative code: -1 when fstat() failed,
 * -2 when fd is not a regular file, -3 when the size does not fit an int.
 */
static int get_infd_size(int fd)
{
	struct stat sb;

	if (fstat(fd, &sb) < 0) {
		perror("fstat");
		return -1;
	}

	if (!S_ISREG(sb.st_mode)) {
		fprintf(stderr, "%s: stdin is not a regular file\n", __func__);
		return -2;
	}

	/* compare the full off_t value, before any narrowing conversion */
	if (sb.st_size > INT_MAX) {
		fprintf(stderr, "File too large: %lld\n", (long long)sb.st_size);
		return -3;
	}

	return (int)sb.st_size;
}
711 
/*
 * Send count bytes from infd (starting at its current file offset) to
 * outfd with sendfile().
 *
 * Returns 0 once everything was sent and 3 when sendfile() failed or the
 * input ended before count bytes could be transferred.
 */
static int do_sendfile(int infd, int outfd, unsigned int count)
{
	while (count > 0) {
		ssize_t r;

		r = sendfile(outfd, infd, NULL, count);
		if (r < 0) {
			perror("sendfile");
			return 3;
		}

		/*
		 * Zero bytes means the input is exhausted: without this
		 * check the loop would spin forever on a short input.
		 */
		if (r == 0) {
			fprintf(stderr, "sendfile: unexpected EOF, %u bytes left\n",
				count);
			return 3;
		}

		count -= r;
	}

	return 0;
}
728 
729 static int copyfd_io_mmap(int infd, int peerfd, int outfd,
730 			  unsigned int size, bool *in_closed_after_out)
731 {
732 	int err;
733 
734 	if (listen_mode) {
735 		err = do_recvfile(peerfd, outfd);
736 		if (err)
737 			return err;
738 
739 		err = do_mmap(infd, peerfd, size);
740 	} else {
741 		err = do_mmap(infd, peerfd, size);
742 		if (err)
743 			return err;
744 
745 		shutdown(peerfd, SHUT_WR);
746 
747 		err = do_recvfile(peerfd, outfd);
748 		*in_closed_after_out = true;
749 	}
750 
751 	return err;
752 }
753 
754 static int copyfd_io_sendfile(int infd, int peerfd, int outfd,
755 			      unsigned int size, bool *in_closed_after_out)
756 {
757 	int err;
758 
759 	if (listen_mode) {
760 		err = do_recvfile(peerfd, outfd);
761 		if (err)
762 			return err;
763 
764 		err = do_sendfile(infd, peerfd, size);
765 	} else {
766 		err = do_sendfile(infd, peerfd, size);
767 		if (err)
768 			return err;
769 		err = do_recvfile(peerfd, outfd);
770 		*in_closed_after_out = true;
771 	}
772 
773 	return err;
774 }
775 
776 static int copyfd_io(int infd, int peerfd, int outfd)
777 {
778 	bool in_closed_after_out = false;
779 	struct timespec start, end;
780 	int file_size;
781 	int ret;
782 
783 	if (cfg_time && (clock_gettime(CLOCK_MONOTONIC, &start) < 0))
784 		xerror("can not fetch start time %d", errno);
785 
786 	switch (cfg_mode) {
787 	case CFG_MODE_POLL:
788 		ret = copyfd_io_poll(infd, peerfd, outfd, &in_closed_after_out);
789 		break;
790 
791 	case CFG_MODE_MMAP:
792 		file_size = get_infd_size(infd);
793 		if (file_size < 0)
794 			return file_size;
795 		ret = copyfd_io_mmap(infd, peerfd, outfd, file_size, &in_closed_after_out);
796 		break;
797 
798 	case CFG_MODE_SENDFILE:
799 		file_size = get_infd_size(infd);
800 		if (file_size < 0)
801 			return file_size;
802 		ret = copyfd_io_sendfile(infd, peerfd, outfd, file_size, &in_closed_after_out);
803 		break;
804 
805 	default:
806 		fprintf(stderr, "Invalid mode %d\n", cfg_mode);
807 
808 		die_usage();
809 		return 1;
810 	}
811 
812 	if (ret)
813 		return ret;
814 
815 	if (cfg_time) {
816 		unsigned int delta_ms;
817 
818 		if (clock_gettime(CLOCK_MONOTONIC, &end) < 0)
819 			xerror("can not fetch end time %d", errno);
820 		delta_ms = (end.tv_sec - start.tv_sec) * 1000 + (end.tv_nsec - start.tv_nsec) / 1000000;
821 		if (delta_ms > cfg_time) {
822 			xerror("transfer slower than expected! runtime %d ms, expected %d ms",
823 			       delta_ms, cfg_time);
824 		}
825 
826 		/* show the runtime only if this end shutdown(wr) before receiving the EOF,
827 		 * (that is, if this end got the longer runtime)
828 		 */
829 		if (in_closed_after_out)
830 			fprintf(stderr, "%d", delta_ms);
831 	}
832 
833 	return 0;
834 }
835 
/*
 * Sanity check the peer address returned by accept().
 *
 * Complaints are written to stderr when the source port is 0, when the
 * address length does not match the expected family pf, or when the
 * family stored in the address differs from pf.  Nothing is returned;
 * the checks are purely diagnostic.
 */
static void check_sockaddr(int pf, struct sockaddr_storage *ss,
			   socklen_t salen)
{
	struct sockaddr_in6 *sin6;
	struct sockaddr_in *sin;
	socklen_t wanted_size = 0;

	switch (pf) {
	case AF_INET:
		wanted_size = sizeof(*sin);
		sin = (void *)ss;
		if (!sin->sin_port)
			fprintf(stderr, "accept: something wrong: ip connection from port 0\n");
		break;
	case AF_INET6:
		wanted_size = sizeof(*sin6);
		sin6 = (void *)ss;
		if (!sin6->sin6_port)
			fprintf(stderr, "accept: something wrong: ipv6 connection from port 0\n");
		break;
	default:
		fprintf(stderr, "accept: Unknown pf %d, salen %u\n", pf, salen);
		return;
	}

	if (salen != wanted_size)
		fprintf(stderr, "accept: size mismatch, got %d expected %d\n",
			(int)salen, (int)wanted_size);

	/* arguments follow the wording: expected family first, actual second */
	if (ss->ss_family != pf)
		fprintf(stderr, "accept: pf mismatch, expect %d, ss_family is %d\n",
			pf, (int)ss->ss_family);
}
869 
/*
 * Verify that getpeername() on fd reports the same address that accept()
 * returned in ss/salen.  Differences are only reported on stderr.
 */
static void check_getpeername(int fd, struct sockaddr_storage *ss, socklen_t salen)
{
	struct sockaddr_storage peerss;
	socklen_t peersalen = sizeof(peerss);
	char acc_host[INET6_ADDRSTRLEN];
	char acc_serv[INET6_ADDRSTRLEN];
	char peer_host[INET6_ADDRSTRLEN];
	char peer_serv[INET6_ADDRSTRLEN];

	if (getpeername(fd, (struct sockaddr *)&peerss, &peersalen) < 0) {
		perror("getpeername");
		return;
	}

	if (peersalen != salen) {
		fprintf(stderr, "%s: %d vs %d\n", __func__, peersalen, salen);
		return;
	}

	if (memcmp(ss, &peerss, peersalen) == 0)
		return;

	/* the raw bytes differ: print both addresses in readable form */
	xgetnameinfo((struct sockaddr *)ss, salen,
		     acc_host, sizeof(acc_host), acc_serv, sizeof(acc_serv));

	xgetnameinfo((struct sockaddr *)&peerss, peersalen,
		     peer_host, sizeof(peer_host), peer_serv, sizeof(peer_serv));

	fprintf(stderr, "%s: memcmp failure: accept %s vs peername %s, %s vs %s salen %d vs %d\n",
		__func__, acc_host, peer_host, acc_serv, peer_serv, peersalen, salen);
}
901 
902 static void check_getpeername_connect(int fd)
903 {
904 	struct sockaddr_storage ss;
905 	socklen_t salen = sizeof(ss);
906 	char a[INET6_ADDRSTRLEN];
907 	char b[INET6_ADDRSTRLEN];
908 
909 	if (getpeername(fd, (struct sockaddr *)&ss, &salen) < 0) {
910 		perror("getpeername");
911 		return;
912 	}
913 
914 	xgetnameinfo((struct sockaddr *)&ss, salen,
915 		     a, sizeof(a), b, sizeof(b));
916 
917 	if (strcmp(cfg_host, a) || strcmp(cfg_port, b))
918 		fprintf(stderr, "%s: %s vs %s, %s vs %s\n", __func__,
919 			cfg_host, a, cfg_port, b);
920 }
921 
922 static void maybe_close(int fd)
923 {
924 	unsigned int r = rand();
925 
926 	if (!(cfg_join || cfg_remove) && (r & 1))
927 		close(fd);
928 }
929 
930 int main_loop_s(int listensock)
931 {
932 	struct sockaddr_storage ss;
933 	struct pollfd polls;
934 	socklen_t salen;
935 	int remotesock;
936 
937 	polls.fd = listensock;
938 	polls.events = POLLIN;
939 
940 	switch (poll(&polls, 1, poll_timeout)) {
941 	case -1:
942 		perror("poll");
943 		return 1;
944 	case 0:
945 		fprintf(stderr, "%s: timed out\n", __func__);
946 		close(listensock);
947 		return 2;
948 	}
949 
950 	salen = sizeof(ss);
951 	remotesock = accept(listensock, (struct sockaddr *)&ss, &salen);
952 	if (remotesock >= 0) {
953 		maybe_close(listensock);
954 		check_sockaddr(pf, &ss, salen);
955 		check_getpeername(remotesock, &ss, salen);
956 
957 		return copyfd_io(0, remotesock, 1);
958 	}
959 
960 	perror("accept");
961 
962 	return 1;
963 }
964 
/*
 * Seed rand().
 *
 * The seed is taken from /dev/urandom.  When the device cannot be opened
 * or does not deliver a full value, a seed derived from the current time
 * and the process id is used instead, so no uninitialised value is ever
 * passed to srand().
 */
static void init_rng(void)
{
	unsigned int seed = (unsigned int)time(NULL) ^ (unsigned int)getpid();
	int fd = open("/dev/urandom", O_RDONLY);

	/* 0 is a valid descriptor, only negative values signal failure */
	if (fd >= 0) {
		unsigned int rnd;

		if (read(fd, &rnd, sizeof(rnd)) == (ssize_t)sizeof(rnd))
			seed = rnd;
		close(fd);
	}

	srand(seed);
}
980 
/* setsockopt() wrapper that terminates the program (exit status 1) on failure. */
static void xsetsockopt(int fd, int level, int optname, const void *optval, socklen_t optlen)
{
	if (setsockopt(fd, level, optname, optval, optlen) == 0)
		return;

	perror("setsockopt");
	exit(1);
}
991 
992 static void apply_cmsg_types(int fd, const struct cfg_cmsg_types *cmsg)
993 {
994 	static const unsigned int on = 1;
995 
996 	if (cmsg->timestampns)
997 		xsetsockopt(fd, SOL_SOCKET, SO_TIMESTAMPNS_NEW, &on, sizeof(on));
998 	if (cmsg->tcp_inq)
999 		xsetsockopt(fd, IPPROTO_TCP, TCP_INQ, &on, sizeof(on));
1000 }
1001 
1002 static void parse_cmsg_types(const char *type)
1003 {
1004 	char *next = strchr(type, ',');
1005 	unsigned int len = 0;
1006 
1007 	cfg_cmsg_types.cmsg_enabled = 1;
1008 
1009 	if (next) {
1010 		parse_cmsg_types(next + 1);
1011 		len = next - type;
1012 	} else {
1013 		len = strlen(type);
1014 	}
1015 
1016 	if (strncmp(type, "TIMESTAMPNS", len) == 0) {
1017 		cfg_cmsg_types.timestampns = 1;
1018 		return;
1019 	}
1020 
1021 	if (strncmp(type, "TCPINQ", len) == 0) {
1022 		cfg_cmsg_types.tcp_inq = 1;
1023 		return;
1024 	}
1025 
1026 	fprintf(stderr, "Unrecognized cmsg option %s\n", type);
1027 	exit(1);
1028 }
1029 
1030 static void parse_setsock_options(const char *name)
1031 {
1032 	char *next = strchr(name, ',');
1033 	unsigned int len = 0;
1034 
1035 	if (next) {
1036 		parse_setsock_options(next + 1);
1037 		len = next - name;
1038 	} else {
1039 		len = strlen(name);
1040 	}
1041 
1042 	if (strncmp(name, "TRANSPARENT", len) == 0) {
1043 		cfg_sockopt_types.transparent = 1;
1044 		return;
1045 	}
1046 
1047 	fprintf(stderr, "Unrecognized setsockopt option %s\n", name);
1048 	exit(1);
1049 }
1050 
1051 int main_loop(void)
1052 {
1053 	int fd;
1054 
1055 	/* listener is ready. */
1056 	fd = sock_connect_mptcp(cfg_host, cfg_port, cfg_sock_proto);
1057 	if (fd < 0)
1058 		return 2;
1059 
1060 	check_getpeername_connect(fd);
1061 
1062 	if (cfg_rcvbuf)
1063 		set_rcvbuf(fd, cfg_rcvbuf);
1064 	if (cfg_sndbuf)
1065 		set_sndbuf(fd, cfg_sndbuf);
1066 	if (cfg_cmsg_types.cmsg_enabled)
1067 		apply_cmsg_types(fd, &cfg_cmsg_types);
1068 
1069 	return copyfd_io(0, fd, 1);
1070 }
1071 
/*
 * Map the argument of -s ("MPTCP" or "TCP", case insensitive) to the
 * matching IPPROTO_* value.  Any other string prints an error followed by
 * the usage text and terminates the program.
 */
int parse_proto(const char *proto)
{
	if (!strcasecmp(proto, "MPTCP"))
		return IPPROTO_MPTCP;
	if (!strcasecmp(proto, "TCP"))
		return IPPROTO_TCP;

	fprintf(stderr, "Unknown protocol: %s\n", proto);
	die_usage();

	/* silence compiler warning */
	return 0;
}
1085 
1086 int parse_mode(const char *mode)
1087 {
1088 	if (!strcasecmp(mode, "poll"))
1089 		return CFG_MODE_POLL;
1090 	if (!strcasecmp(mode, "mmap"))
1091 		return CFG_MODE_MMAP;
1092 	if (!strcasecmp(mode, "sendfile"))
1093 		return CFG_MODE_SENDFILE;
1094 
1095 	fprintf(stderr, "Unknown test mode: %s\n", mode);
1096 	fprintf(stderr, "Supported modes are:\n");
1097 	fprintf(stderr, "\t\t\"poll\" - interleaved read/write using poll()\n");
1098 	fprintf(stderr, "\t\t\"mmap\" - send entire input file (mmap+write), then read response (-l will read input first)\n");
1099 	fprintf(stderr, "\t\t\"sendfile\" - send entire input file (sendfile), then read response (-l will read input first)\n");
1100 
1101 	die_usage();
1102 
1103 	/* silence compiler warning */
1104 	return 0;
1105 }
1106 
1107 int parse_peek(const char *mode)
1108 {
1109 	if (!strcasecmp(mode, "saveWithPeek"))
1110 		return CFG_WITH_PEEK;
1111 	if (!strcasecmp(mode, "saveAfterPeek"))
1112 		return CFG_AFTER_PEEK;
1113 
1114 	fprintf(stderr, "Unknown: %s\n", mode);
1115 	fprintf(stderr, "Supported MSG_PEEK mode are:\n");
1116 	fprintf(stderr,
1117 		"\t\t\"saveWithPeek\" - recv data with flags 'MSG_PEEK' and save the peek data into file\n");
1118 	fprintf(stderr,
1119 		"\t\t\"saveAfterPeek\" - read and save data into file after recv with flags 'MSG_PEEK'\n");
1120 
1121 	die_usage();
1122 
1123 	/* silence compiler warning */
1124 	return 0;
1125 }
1126 
/*
 * Convert the buffer size given to -S or -R into an int.
 *
 * The number may be decimal, octal (leading 0) or hexadecimal (leading
 * 0x).  An empty string, trailing garbage, a conversion error or a value
 * above INT_MAX prints a message followed by the usage text and
 * terminates the program.
 */
static int parse_int(const char *size)
{
	unsigned long s;
	char *end;

	errno = 0;

	s = strtoul(size, &end, 0);

	if (errno) {
		fprintf(stderr, "Invalid buffer size %s (%s)\n",
			size, strerror(errno));
		die_usage();
	}

	/* strtoul() reports neither "no digits" nor trailing junk via errno */
	if (end == size || *end != '\0') {
		fprintf(stderr, "Invalid buffer size %s (%s)\n",
			size, strerror(EINVAL));
		die_usage();
	}

	if (s > INT_MAX) {
		fprintf(stderr, "Invalid buffer size %s (%s)\n",
			size, strerror(ERANGE));
		die_usage();
	}

	return (int)s;
}
1149 
1150 static void parse_opts(int argc, char **argv)
1151 {
1152 	int c;
1153 
1154 	while ((c = getopt(argc, argv, "6jr:lp:s:hut:T:m:S:R:w:M:P:c:o:")) != -1) {
1155 		switch (c) {
1156 		case 'j':
1157 			cfg_join = true;
1158 			cfg_mode = CFG_MODE_POLL;
1159 			break;
1160 		case 'r':
1161 			cfg_remove = true;
1162 			cfg_mode = CFG_MODE_POLL;
1163 			cfg_wait = 400000;
1164 			cfg_do_w = atoi(optarg);
1165 			if (cfg_do_w <= 0)
1166 				cfg_do_w = 50;
1167 			break;
1168 		case 'l':
1169 			listen_mode = true;
1170 			break;
1171 		case 'p':
1172 			cfg_port = optarg;
1173 			break;
1174 		case 's':
1175 			cfg_sock_proto = parse_proto(optarg);
1176 			break;
1177 		case 'h':
1178 			die_usage();
1179 			break;
1180 		case 'u':
1181 			tcpulp_audit = true;
1182 			break;
1183 		case '6':
1184 			pf = AF_INET6;
1185 			break;
1186 		case 't':
1187 			poll_timeout = atoi(optarg) * 1000;
1188 			if (poll_timeout <= 0)
1189 				poll_timeout = -1;
1190 			break;
1191 		case 'T':
1192 			cfg_time = atoi(optarg);
1193 			break;
1194 		case 'm':
1195 			cfg_mode = parse_mode(optarg);
1196 			break;
1197 		case 'S':
1198 			cfg_sndbuf = parse_int(optarg);
1199 			break;
1200 		case 'R':
1201 			cfg_rcvbuf = parse_int(optarg);
1202 			break;
1203 		case 'w':
1204 			cfg_wait = atoi(optarg)*1000000;
1205 			break;
1206 		case 'M':
1207 			cfg_mark = strtol(optarg, NULL, 0);
1208 			break;
1209 		case 'P':
1210 			cfg_peek = parse_peek(optarg);
1211 			break;
1212 		case 'c':
1213 			parse_cmsg_types(optarg);
1214 			break;
1215 		case 'o':
1216 			parse_setsock_options(optarg);
1217 			break;
1218 		}
1219 	}
1220 
1221 	if (optind + 1 != argc)
1222 		die_usage();
1223 	cfg_host = argv[optind];
1224 
1225 	if (strchr(cfg_host, ':'))
1226 		pf = AF_INET6;
1227 }
1228 
1229 int main(int argc, char *argv[])
1230 {
1231 	init_rng();
1232 
1233 	signal(SIGUSR1, handle_signal);
1234 	parse_opts(argc, argv);
1235 
1236 	if (tcpulp_audit)
1237 		return sock_test_tcpulp(cfg_host, cfg_port) ? 0 : 1;
1238 
1239 	if (listen_mode) {
1240 		int fd = sock_listen_mptcp(cfg_host, cfg_port);
1241 
1242 		if (fd < 0)
1243 			return 1;
1244 
1245 		if (cfg_rcvbuf)
1246 			set_rcvbuf(fd, cfg_rcvbuf);
1247 		if (cfg_sndbuf)
1248 			set_sndbuf(fd, cfg_sndbuf);
1249 		if (cfg_mark)
1250 			set_mark(fd, cfg_mark);
1251 		if (cfg_cmsg_types.cmsg_enabled)
1252 			apply_cmsg_types(fd, &cfg_cmsg_types);
1253 
1254 		return main_loop_s(fd);
1255 	}
1256 
1257 	return main_loop();
1258 }
1259