1 // SPDX-License-Identifier: GPL-2.0
2
3 #define _GNU_SOURCE
4
5 #include <assert.h>
6 #include <errno.h>
7 #include <fcntl.h>
8 #include <limits.h>
9 #include <string.h>
10 #include <stdarg.h>
11 #include <stdbool.h>
12 #include <stdint.h>
13 #include <inttypes.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <strings.h>
17 #include <unistd.h>
18 #include <time.h>
19
20 #include <sys/ioctl.h>
21 #include <sys/random.h>
22 #include <sys/socket.h>
23 #include <sys/types.h>
24 #include <sys/wait.h>
25
26 #include <netdb.h>
27 #include <netinet/in.h>
28
29 #include <linux/tcp.h>
30 #include <linux/sockios.h>
31 #include <linux/compiler.h>
32
33 #ifndef IPPROTO_MPTCP
34 #define IPPROTO_MPTCP 262
35 #endif
36 #ifndef SOL_MPTCP
37 #define SOL_MPTCP 284
38 #endif
39
/* Address family used by both endpoints; AF_INET unless -6 selects AF_INET6. */
static int pf = AF_INET;
/* Protocol the connecting side passes to socket(); set by the -t option. */
static int proto_tx = IPPROTO_MPTCP;
/* Protocol the listening side passes to socket(); set by the -r option. */
static int proto_rx = IPPROTO_MPTCP;
43
/* Report @msg together with the text for the current errno on stderr,
 * then terminate the process with exit status 1. Never returns.
 */
static void __noreturn die_perror(const char *msg)
{
	perror(msg);
	exit(1);
}
49
/* Print the command line synopsis on stderr and exit with status @r.
 *
 * Marked __noreturn like the other die/xerror helpers in this file, so the
 * compiler knows control never comes back to the caller.
 */
static void __noreturn die_usage(int r)
{
	fprintf(stderr, "Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n");
	exit(r);
}
55
/* Print a printf-style message followed by a newline on stderr, then
 * terminate the process with exit status 1. Never returns.
 *
 * The format attribute makes the compiler check every caller's arguments
 * against @fmt.
 */
static void __noreturn __attribute__((format(printf, 1, 2)))
xerror(const char *fmt, ...)
{
	va_list ap;

	va_start(ap, fmt);
	vfprintf(stderr, fmt, ap);
	va_end(ap);
	fputc('\n', stderr);
	exit(1);
}
66
/* Translate a getaddrinfo() error code into a printable string.
 *
 * EAI_SYSTEM means the real cause is in errno, so that is described
 * instead; every other code is handed to gai_strerror().
 */
static const char *getxinfo_strerr(int err)
{
	return err == EAI_SYSTEM ? strerror(errno) : gai_strerror(err);
}
74
/* getaddrinfo() wrapper that never returns on failure.
 *
 * @node, @service, @hints and @res are passed straight to getaddrinfo().
 * If the lookup is rejected with EAI_SOCKTYPE, @hints->ai_protocol is
 * rewritten to IPPROTO_TCP and the lookup is repeated once. The retry is
 * skipped when the protocol already is IPPROTO_TCP, so a persistent
 * EAI_SOCKTYPE ends in the fatal path instead of looping forever.
 *
 * On any remaining error a diagnostic is printed and the process exits
 * with status 1.
 */
static void xgetaddrinfo(const char *node, const char *service,
			 struct addrinfo *hints,
			 struct addrinfo **res)
{
	int err;

	err = getaddrinfo(node, service, hints, res);
	if (err == EAI_SOCKTYPE && hints->ai_protocol != IPPROTO_TCP) {
		hints->ai_protocol = IPPROTO_TCP;
		err = getaddrinfo(node, service, hints, res);
	}

	if (!err)
		return;

	fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
		node ? node : "", service ? service : "",
		getxinfo_strerr(err));
	exit(1);
}
98
/* Create a listening stream socket bound to @listenaddr:@port.
 *
 * The address is resolved as a numeric host in family 'pf'. Each result is
 * tried in turn with protocol 'proto_rx' until one can be bound. A failure
 * to set SO_REUSEADDR is only reported; a failure to bind moves on to the
 * next result. Exits if no socket could be bound or listen() fails.
 *
 * Returns the listening socket descriptor.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_MPTCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
	};
	struct addrinfo *results, *cur;
	int reuse = 1;
	int lfd = -1;

	hints.ai_family = pf;
	xgetaddrinfo(listenaddr, port, &hints, &results);

	/* stop at the first candidate that binds */
	for (cur = results; cur && lfd < 0; cur = cur->ai_next) {
		lfd = socket(cur->ai_family, cur->ai_socktype, proto_rx);
		if (lfd < 0)
			continue;

		if (setsockopt(lfd, SOL_SOCKET, SO_REUSEADDR, &reuse,
			       sizeof(reuse)) == -1)
			perror("setsockopt");

		if (bind(lfd, cur->ai_addr, cur->ai_addrlen) != 0) {
			perror("bind");
			close(lfd);
			lfd = -1;
		}
	}

	freeaddrinfo(results);

	if (lfd < 0)
		xerror("could not create listen socket");

	if (listen(lfd, 20))
		die_perror("listen");

	return lfd;
}
144
/* Open a stream socket with protocol @proto and connect it to
 * @remoteaddr:@port, resolved in family 'pf'.
 *
 * Results for which socket() fails are skipped. The first result that
 * yields a socket must also connect, otherwise the process exits. Exits as
 * well when no socket could be created at all.
 *
 * Returns the connected socket descriptor.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_MPTCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *results, *cur;
	int fd = -1;

	hints.ai_family = pf;

	xgetaddrinfo(remoteaddr, port, &hints, &results);

	for (cur = results; cur; cur = cur->ai_next) {
		fd = socket(cur->ai_family, cur->ai_socktype, proto);
		if (fd < 0)
			continue;

		if (connect(fd, cur->ai_addr, cur->ai_addrlen) != 0)
			die_perror("connect");

		break; /* connected */
	}

	if (fd < 0)
		xerror("could not create connect socket");

	freeaddrinfo(results);
	return fd;
}
175
/* Map a protocol name from the command line to its IPPROTO_* number.
 *
 * "tcp" and "mptcp" are accepted case-insensitively; any other string
 * ends in die_usage(1).
 */
static int protostr_to_num(const char *s)
{
	static const struct {
		const char *name;
		int proto;
	} known[] = {
		{ "tcp", IPPROTO_TCP },
		{ "mptcp", IPPROTO_MPTCP },
	};
	size_t idx;

	for (idx = 0; idx < sizeof(known) / sizeof(known[0]); idx++) {
		if (strcasecmp(s, known[idx].name) == 0)
			return known[idx].proto;
	}

	die_usage(1);
	return 0;
}
186
parse_opts(int argc,char ** argv)187 static void parse_opts(int argc, char **argv)
188 {
189 int c;
190
191 while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
192 switch (c) {
193 case 'h':
194 die_usage(0);
195 break;
196 case '6':
197 pf = AF_INET6;
198 break;
199 case 't':
200 proto_tx = protostr_to_num(optarg);
201 break;
202 case 'r':
203 proto_rx = protostr_to_num(optarg);
204 break;
205 default:
206 die_usage(1);
207 break;
208 }
209 }
210 }
211
/* Wait up to @timeout milliseconds for the send queue of @fd to drain.
 *
 * Polls TIOCOUTQ and SIOCOUTQNSD once per iteration and sleeps 1ms between
 * iterations. Returns as soon as TIOCOUTQ reports 0. Exits when an ioctl
 * fails, when TIOCOUTQ reports more than @total bytes, or when data is
 * still queued after @timeout iterations. The SIOCOUTQNSD value is
 * asserted to never exceed the TIOCOUTQ value.
 */
static void wait_for_ack(int fd, int timeout, size_t total)
{
	int i;

	for (i = 0; i < timeout; i++) {
		int nsd, ret, queued = -1;
		struct timespec req;

		ret = ioctl(fd, TIOCOUTQ, &queued);
		if (ret < 0)
			die_perror("TIOCOUTQ");

		ret = ioctl(fd, SIOCOUTQNSD, &nsd);
		if (ret < 0)
			die_perror("SIOCOUTQNSD");

		/* queued and timeout are plain int: print them with %d */
		if ((size_t)queued > total)
			xerror("TIOCOUTQ %d, but only %zu expected\n", queued, total);
		assert(nsd <= queued);

		if (queued == 0)
			return;

		/* wait for peer to ack rx of all data */
		req.tv_sec = 0;
		req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
		nanosleep(&req, NULL);
	}

	xerror("still tx data queued after %d ms\n", timeout);
}
244
/* Client half of the test. @fd is the connected stream socket under test,
 * @unixfd the control channel shared with the server process.
 *
 * Sequence:
 *  1. wait for "xmit", announce a random length (128..4094) on @unixfd and
 *     write that many bytes to @fd;
 *  2. wait for "huge", announce a random total of 1MiB up to just under
 *     17MiB, wait for the first chunk to leave the send queue, then write
 *     the announced total;
 *  3. wait for "shut", wait for the send queue to drain, write one more
 *     byte, close @fd and report "closed" on @unixfd.
 *
 * Every control-channel read has its length checked so that a failed read
 * is caught before the token comparison looks at stale buffer contents.
 */
static void connect_one_server(int fd, int unixfd)
{
	size_t len, i, total, sent;
	char buf[4096], buf2[4096];
	ssize_t ret;

	len = rand() % (sizeof(buf) - 1);

	if (len < 128)
		len = 128;

	for (i = 0; i < len ; i++) {
		buf[i] = rand() % 26;
		buf[i] += 'A';
	}

	buf[i] = '\n';

	/* un-block server */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);

	assert(strncmp(buf2, "xmit", 4) == 0);

	ret = write(unixfd, &len, sizeof(len));
	assert(ret == (ssize_t)sizeof(len));

	ret = write(fd, buf, len);
	if (ret < 0)
		die_perror("write");

	if (ret != (ssize_t)len)
		xerror("short write");

	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "huge", 4) == 0);

	total = rand() % (16 * 1024 * 1024);
	total += (1 * 1024 * 1024);
	sent = total;

	ret = write(unixfd, &total, sizeof(total));
	assert(ret == (ssize_t)sizeof(total));

	wait_for_ack(fd, 5000, len);

	while (total > 0) {
		if (total > sizeof(buf))
			len = sizeof(buf);
		else
			len = total;

		ret = write(fd, buf, len);
		if (ret < 0)
			die_perror("write");
		total -= ret;

		/* we don't have to care about buf content, only
		 * number of total bytes sent
		 */
	}

	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "shut", 4) == 0);

	wait_for_ack(fd, 5000, sent);

	ret = write(fd, buf, 1);
	assert(ret == 1);
	close(fd);
	ret = write(unixfd, "closed", 6);
	assert(ret == 6);

	close(unixfd);
}
321
/* Copy the value of the TCP_CM_INQ control message found in @msgh into
 * @inqv.
 *
 * Walks all control messages of @msgh and takes the first one with level
 * IPPROTO_TCP and type TCP_CM_INQ. Exits through xerror() if there is no
 * such control message.
 */
static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
{
	struct cmsghdr *hdr = CMSG_FIRSTHDR(msgh);

	while (hdr) {
		if (hdr->cmsg_level != IPPROTO_TCP ||
		    hdr->cmsg_type != TCP_CM_INQ) {
			hdr = CMSG_NXTHDR(msgh, hdr);
			continue;
		}

		memcpy(inqv, CMSG_DATA(hdr), sizeof(*inqv));
		return;
	}

	xerror("could not find TCP_CM_INQ cmsg type");
}
335
/* Server half of the test. @fd is the accepted stream socket, on which the
 * caller has enabled TCP_INQ; @unixfd is the control channel shared with
 * the client process.
 *
 * Sequence:
 *  1. send "xmit", read the announced length and poll FIONREAD until it
 *     reports exactly that many bytes;
 *  2. read a single byte and check that the TCP_CM_INQ cmsg reports the
 *     announced length minus one, then read the remainder and check that
 *     the cmsg reports 0;
 *  3. send "huge", read the announced total and receive it in chunks,
 *     checking after each chunk that the reported inq never exceeds what
 *     is still outstanding;
 *  4. send "shut", wait for "closed", then check that the final data byte
 *     and the following EOF both come with an inq value of 1.
 *
 * msg_controllen is value-result: recvmsg() overwrites it with the length
 * of the control data it returned. It is therefore restored to the full
 * buffer size before every call after the first.
 */
static void process_one_client(int fd, int unixfd)
{
	unsigned int tcp_inq;
	size_t expect_len;
	char msg_buf[4096];
	char buf[4096];
	char tmp[16];
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = 1,
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = msg_buf,
		.msg_controllen = sizeof(msg_buf),
	};
	ssize_t ret, tot;

	ret = write(unixfd, "xmit", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	if (expect_len > sizeof(buf))
		xerror("expect len %zu exceeds buffer size", expect_len);

	for (;;) {
		struct timespec req;
		unsigned int queued;

		ret = ioctl(fd, FIONREAD, &queued);
		if (ret < 0)
			die_perror("FIONREAD");
		if (queued > expect_len)
			xerror("FIONREAD returned %u, but only %zu expected\n",
			       queued, expect_len);
		if (queued == expect_len)
			break;

		req.tv_sec = 0;
		req.tv_nsec = 1000 * 1000ul;
		nanosleep(&req, NULL);
	}

	/* read one byte, expect cmsg to return expected - 1 */
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	if (msg.msg_controllen == 0)
		xerror("msg_controllen is 0");

	get_tcp_inq(&msg, &tcp_inq);

	assert((size_t)tcp_inq == (expect_len - 1));

	iov.iov_len = sizeof(buf);
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* should have gotten exact remainder of all pending data */
	assert(ret == (ssize_t)tcp_inq);

	/* should be 0, all drained */
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 0);

	/* request a large swath of data. */
	ret = write(unixfd, "huge", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	/* peer should send us a few mb of data */
	if (expect_len <= sizeof(buf))
		xerror("expect len %zu too small\n", expect_len);

	tot = 0;
	do {
		iov.iov_len = sizeof(buf);
		msg.msg_controllen = sizeof(msg_buf);
		ret = recvmsg(fd, &msg, 0);
		if (ret < 0)
			die_perror("recvmsg");

		tot += ret;

		get_tcp_inq(&msg, &tcp_inq);

		/* tcp_inq is unsigned int, the other two are size_t */
		if (tcp_inq > expect_len - tot)
			xerror("inq %u, remaining %zu total_len %zu\n",
			       tcp_inq, expect_len - (size_t)tot, expect_len);

		assert(tcp_inq <= expect_len - tot);
	} while ((size_t)tot < expect_len);

	ret = write(unixfd, "shut", 4);
	assert(ret == 4);

	/* wait for hangup. Should have received one more byte of data. */
	ret = read(unixfd, tmp, sizeof(tmp));
	assert(ret == 6);
	assert(strncmp(tmp, "closed", 6) == 0);

	sleep(1);

	iov.iov_len = 1;
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");
	assert(ret == 1);

	get_tcp_inq(&msg, &tcp_inq);

	/* tcp_inq should be 1 due to received fin. */
	assert(tcp_inq == 1);

	iov.iov_len = 1;
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* expect EOF */
	assert(ret == 0);
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 1);

	close(fd);
}
468
/* accept() one connection on listening socket @s without asking for the
 * peer address. Exits through die_perror() on failure, otherwise returns
 * the new socket descriptor.
 */
static int xaccept(int s)
{
	int conn = accept(s, NULL, NULL);

	if (conn < 0)
		die_perror("accept");

	return conn;
}
478
server(int unixfd)479 static int server(int unixfd)
480 {
481 int fd = -1, r, on = 1;
482
483 switch (pf) {
484 case AF_INET:
485 fd = sock_listen_mptcp("127.0.0.1", "15432");
486 break;
487 case AF_INET6:
488 fd = sock_listen_mptcp("::1", "15432");
489 break;
490 default:
491 xerror("Unknown pf %d\n", pf);
492 break;
493 }
494
495 r = write(unixfd, "conn", 4);
496 assert(r == 4);
497
498 alarm(15);
499 r = xaccept(fd);
500
501 if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
502 die_perror("setsockopt");
503
504 process_one_client(r, unixfd);
505
506 close(fd);
507 return 0;
508 }
509
client(int unixfd)510 static int client(int unixfd)
511 {
512 int fd = -1;
513
514 alarm(15);
515
516 switch (pf) {
517 case AF_INET:
518 fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
519 break;
520 case AF_INET6:
521 fd = sock_connect_mptcp("::1", "15432", proto_tx);
522 break;
523 default:
524 xerror("Unknown pf %d\n", pf);
525 }
526
527 connect_one_server(fd, unixfd);
528
529 return 0;
530 }
531
/* Seed rand() with an unsigned int obtained from getrandom().
 * Reports the error and exits with status 1 if getrandom() fails.
 */
static void init_rng(void)
{
	unsigned int seed;

	if (getrandom(&seed, sizeof(seed), 0) != -1) {
		srand(seed);
		return;
	}

	perror("getrandom");
	exit(1);
}
543
/* fork() wrapper: exits through die_perror() on failure and re-seeds the
 * random number generator in the child.
 *
 * Returns what fork() returned: 0 in the child, the child's pid in the
 * parent.
 */
static pid_t xfork(void)
{
	pid_t pid = fork();

	if (pid < 0)
		die_perror("fork");

	if (pid == 0)
		init_rng();

	return pid;
}
555
/* Turn a waitpid() status for the child called @what into an exit code.
 *
 * A normal exit yields the child's exit status, with a message on stderr
 * when it is non-zero. A child killed or stopped by a signal ends the
 * whole program through xerror(). Any other status yields 111.
 */
static int rcheck(int wstatus, const char *what)
{
	int code;

	if (!WIFEXITED(wstatus)) {
		if (WIFSIGNALED(wstatus))
			xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
		if (WIFSTOPPED(wstatus))
			xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));

		return 111;
	}

	code = WEXITSTATUS(wstatus);
	if (code != 0)
		fprintf(stderr, "%s exited, status=%d\n", what, code);

	return code;
}
571
/* Test driver.
 *
 * Parses the options, creates an AF_UNIX datagram socket pair as control
 * channel, forks the server, waits for its 4-byte ready token, forks the
 * client, then reaps both children.
 *
 * Returns the server's result if it is non-zero, otherwise the client's.
 */
int main(int argc, char *argv[])
{
	int e1, e2, wstatus;
	pid_t s, c, ret;
	int unixfds[2];
	char token[4];
	ssize_t n;

	parse_opts(argc, argv);

	/* name the call that actually failed */
	if (socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds) < 0)
		die_perror("socketpair");

	s = xfork();
	if (s == 0) {
		close(unixfds[0]);
		e1 = server(unixfds[1]);
		close(unixfds[1]);
		return e1;
	}

	close(unixfds[1]);

	/* wait until server bound a socket; keep the received token in its
	 * own buffer instead of the variable holding read()'s result
	 */
	n = read(unixfds[0], token, sizeof(token));
	assert(n == (ssize_t)sizeof(token));

	c = xfork();
	if (c == 0)
		return client(unixfds[0]);

	close(unixfds[0]);

	ret = waitpid(s, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e1 = rcheck(wstatus, "server");
	ret = waitpid(c, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e2 = rcheck(wstatus, "client");

	return e1 ? e1 : e2;
}
615