xref: /linux/tools/testing/vsock/vsock_test.c (revision 1a2ac6d7ecdcde74a4e16f31de64124160fc7237)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * vsock_test - vsock.ko test suite
4  *
5  * Copyright (C) 2017 Red Hat, Inc.
6  *
7  * Author: Stefan Hajnoczi <stefanha@redhat.com>
8  */
9 
10 #include <getopt.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <string.h>
14 #include <errno.h>
15 #include <unistd.h>
16 #include <linux/kernel.h>
17 #include <sys/types.h>
18 #include <sys/socket.h>
19 #include <time.h>
20 #include <sys/mman.h>
21 #include <poll.h>
22 
23 #include "timeout.h"
24 #include "control.h"
25 #include "util.h"
26 
27 static void test_stream_connection_reset(const struct test_opts *opts)
28 {
29 	union {
30 		struct sockaddr sa;
31 		struct sockaddr_vm svm;
32 	} addr = {
33 		.svm = {
34 			.svm_family = AF_VSOCK,
35 			.svm_port = 1234,
36 			.svm_cid = opts->peer_cid,
37 		},
38 	};
39 	int ret;
40 	int fd;
41 
42 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
43 
44 	timeout_begin(TIMEOUT);
45 	do {
46 		ret = connect(fd, &addr.sa, sizeof(addr.svm));
47 		timeout_check("connect");
48 	} while (ret < 0 && errno == EINTR);
49 	timeout_end();
50 
51 	if (ret != -1) {
52 		fprintf(stderr, "expected connect(2) failure, got %d\n", ret);
53 		exit(EXIT_FAILURE);
54 	}
55 	if (errno != ECONNRESET) {
56 		fprintf(stderr, "unexpected connect(2) errno %d\n", errno);
57 		exit(EXIT_FAILURE);
58 	}
59 
60 	close(fd);
61 }
62 
63 static void test_stream_bind_only_client(const struct test_opts *opts)
64 {
65 	union {
66 		struct sockaddr sa;
67 		struct sockaddr_vm svm;
68 	} addr = {
69 		.svm = {
70 			.svm_family = AF_VSOCK,
71 			.svm_port = 1234,
72 			.svm_cid = opts->peer_cid,
73 		},
74 	};
75 	int ret;
76 	int fd;
77 
78 	/* Wait for the server to be ready */
79 	control_expectln("BIND");
80 
81 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
82 
83 	timeout_begin(TIMEOUT);
84 	do {
85 		ret = connect(fd, &addr.sa, sizeof(addr.svm));
86 		timeout_check("connect");
87 	} while (ret < 0 && errno == EINTR);
88 	timeout_end();
89 
90 	if (ret != -1) {
91 		fprintf(stderr, "expected connect(2) failure, got %d\n", ret);
92 		exit(EXIT_FAILURE);
93 	}
94 	if (errno != ECONNRESET) {
95 		fprintf(stderr, "unexpected connect(2) errno %d\n", errno);
96 		exit(EXIT_FAILURE);
97 	}
98 
99 	/* Notify the server that the client has finished */
100 	control_writeln("DONE");
101 
102 	close(fd);
103 }
104 
105 static void test_stream_bind_only_server(const struct test_opts *opts)
106 {
107 	union {
108 		struct sockaddr sa;
109 		struct sockaddr_vm svm;
110 	} addr = {
111 		.svm = {
112 			.svm_family = AF_VSOCK,
113 			.svm_port = 1234,
114 			.svm_cid = VMADDR_CID_ANY,
115 		},
116 	};
117 	int fd;
118 
119 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
120 
121 	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
122 		perror("bind");
123 		exit(EXIT_FAILURE);
124 	}
125 
126 	/* Notify the client that the server is ready */
127 	control_writeln("BIND");
128 
129 	/* Wait for the client to finish */
130 	control_expectln("DONE");
131 
132 	close(fd);
133 }
134 
135 static void test_stream_client_close_client(const struct test_opts *opts)
136 {
137 	int fd;
138 
139 	fd = vsock_stream_connect(opts->peer_cid, 1234);
140 	if (fd < 0) {
141 		perror("connect");
142 		exit(EXIT_FAILURE);
143 	}
144 
145 	send_byte(fd, 1, 0);
146 	close(fd);
147 }
148 
149 static void test_stream_client_close_server(const struct test_opts *opts)
150 {
151 	int fd;
152 
153 	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
154 	if (fd < 0) {
155 		perror("accept");
156 		exit(EXIT_FAILURE);
157 	}
158 
159 	/* Wait for the remote to close the connection, before check
160 	 * -EPIPE error on send.
161 	 */
162 	vsock_wait_remote_close(fd);
163 
164 	send_byte(fd, -EPIPE, 0);
165 	recv_byte(fd, 1, 0);
166 	recv_byte(fd, 0, 0);
167 	close(fd);
168 }
169 
170 static void test_stream_server_close_client(const struct test_opts *opts)
171 {
172 	int fd;
173 
174 	fd = vsock_stream_connect(opts->peer_cid, 1234);
175 	if (fd < 0) {
176 		perror("connect");
177 		exit(EXIT_FAILURE);
178 	}
179 
180 	/* Wait for the remote to close the connection, before check
181 	 * -EPIPE error on send.
182 	 */
183 	vsock_wait_remote_close(fd);
184 
185 	send_byte(fd, -EPIPE, 0);
186 	recv_byte(fd, 1, 0);
187 	recv_byte(fd, 0, 0);
188 	close(fd);
189 }
190 
191 static void test_stream_server_close_server(const struct test_opts *opts)
192 {
193 	int fd;
194 
195 	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
196 	if (fd < 0) {
197 		perror("accept");
198 		exit(EXIT_FAILURE);
199 	}
200 
201 	send_byte(fd, 1, 0);
202 	close(fd);
203 }
204 
205 /* With the standard socket sizes, VMCI is able to support about 100
206  * concurrent stream connections.
207  */
208 #define MULTICONN_NFDS 100
209 
210 static void test_stream_multiconn_client(const struct test_opts *opts)
211 {
212 	int fds[MULTICONN_NFDS];
213 	int i;
214 
215 	for (i = 0; i < MULTICONN_NFDS; i++) {
216 		fds[i] = vsock_stream_connect(opts->peer_cid, 1234);
217 		if (fds[i] < 0) {
218 			perror("connect");
219 			exit(EXIT_FAILURE);
220 		}
221 	}
222 
223 	for (i = 0; i < MULTICONN_NFDS; i++) {
224 		if (i % 2)
225 			recv_byte(fds[i], 1, 0);
226 		else
227 			send_byte(fds[i], 1, 0);
228 	}
229 
230 	for (i = 0; i < MULTICONN_NFDS; i++)
231 		close(fds[i]);
232 }
233 
234 static void test_stream_multiconn_server(const struct test_opts *opts)
235 {
236 	int fds[MULTICONN_NFDS];
237 	int i;
238 
239 	for (i = 0; i < MULTICONN_NFDS; i++) {
240 		fds[i] = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
241 		if (fds[i] < 0) {
242 			perror("accept");
243 			exit(EXIT_FAILURE);
244 		}
245 	}
246 
247 	for (i = 0; i < MULTICONN_NFDS; i++) {
248 		if (i % 2)
249 			send_byte(fds[i], 1, 0);
250 		else
251 			recv_byte(fds[i], 1, 0);
252 	}
253 
254 	for (i = 0; i < MULTICONN_NFDS; i++)
255 		close(fds[i]);
256 }
257 
258 static void test_stream_msg_peek_client(const struct test_opts *opts)
259 {
260 	int fd;
261 
262 	fd = vsock_stream_connect(opts->peer_cid, 1234);
263 	if (fd < 0) {
264 		perror("connect");
265 		exit(EXIT_FAILURE);
266 	}
267 
268 	send_byte(fd, 1, 0);
269 	close(fd);
270 }
271 
272 static void test_stream_msg_peek_server(const struct test_opts *opts)
273 {
274 	int fd;
275 
276 	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
277 	if (fd < 0) {
278 		perror("accept");
279 		exit(EXIT_FAILURE);
280 	}
281 
282 	recv_byte(fd, 1, MSG_PEEK);
283 	recv_byte(fd, 1, 0);
284 	close(fd);
285 }
286 
287 #define SOCK_BUF_SIZE (2 * 1024 * 1024)
288 #define MAX_MSG_SIZE (32 * 1024)
289 
290 static void test_seqpacket_msg_bounds_client(const struct test_opts *opts)
291 {
292 	unsigned long curr_hash;
293 	int page_size;
294 	int msg_count;
295 	int fd;
296 
297 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
298 	if (fd < 0) {
299 		perror("connect");
300 		exit(EXIT_FAILURE);
301 	}
302 
303 	/* Wait, until receiver sets buffer size. */
304 	control_expectln("SRVREADY");
305 
306 	curr_hash = 0;
307 	page_size = getpagesize();
308 	msg_count = SOCK_BUF_SIZE / MAX_MSG_SIZE;
309 
310 	for (int i = 0; i < msg_count; i++) {
311 		ssize_t send_size;
312 		size_t buf_size;
313 		int flags;
314 		void *buf;
315 
316 		/* Use "small" buffers and "big" buffers. */
317 		if (i & 1)
318 			buf_size = page_size +
319 					(rand() % (MAX_MSG_SIZE - page_size));
320 		else
321 			buf_size = 1 + (rand() % page_size);
322 
323 		buf = malloc(buf_size);
324 
325 		if (!buf) {
326 			perror("malloc");
327 			exit(EXIT_FAILURE);
328 		}
329 
330 		memset(buf, rand() & 0xff, buf_size);
331 		/* Set at least one MSG_EOR + some random. */
332 		if (i == (msg_count / 2) || (rand() & 1)) {
333 			flags = MSG_EOR;
334 			curr_hash++;
335 		} else {
336 			flags = 0;
337 		}
338 
339 		send_size = send(fd, buf, buf_size, flags);
340 
341 		if (send_size < 0) {
342 			perror("send");
343 			exit(EXIT_FAILURE);
344 		}
345 
346 		if (send_size != buf_size) {
347 			fprintf(stderr, "Invalid send size\n");
348 			exit(EXIT_FAILURE);
349 		}
350 
351 		/*
352 		 * Hash sum is computed at both client and server in
353 		 * the same way:
354 		 * H += hash('message data')
355 		 * Such hash "controls" both data integrity and message
356 		 * bounds. After data exchange, both sums are compared
357 		 * using control socket, and if message bounds wasn't
358 		 * broken - two values must be equal.
359 		 */
360 		curr_hash += hash_djb2(buf, buf_size);
361 		free(buf);
362 	}
363 
364 	control_writeln("SENDDONE");
365 	control_writeulong(curr_hash);
366 	close(fd);
367 }
368 
369 static void test_seqpacket_msg_bounds_server(const struct test_opts *opts)
370 {
371 	unsigned long sock_buf_size;
372 	unsigned long remote_hash;
373 	unsigned long curr_hash;
374 	int fd;
375 	char buf[MAX_MSG_SIZE];
376 	struct msghdr msg = {0};
377 	struct iovec iov = {0};
378 
379 	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
380 	if (fd < 0) {
381 		perror("accept");
382 		exit(EXIT_FAILURE);
383 	}
384 
385 	sock_buf_size = SOCK_BUF_SIZE;
386 
387 	if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_MAX_SIZE,
388 		       &sock_buf_size, sizeof(sock_buf_size))) {
389 		perror("setsockopt(SO_VM_SOCKETS_BUFFER_MAX_SIZE)");
390 		exit(EXIT_FAILURE);
391 	}
392 
393 	if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
394 		       &sock_buf_size, sizeof(sock_buf_size))) {
395 		perror("setsockopt(SO_VM_SOCKETS_BUFFER_SIZE)");
396 		exit(EXIT_FAILURE);
397 	}
398 
399 	/* Ready to receive data. */
400 	control_writeln("SRVREADY");
401 	/* Wait, until peer sends whole data. */
402 	control_expectln("SENDDONE");
403 	iov.iov_base = buf;
404 	iov.iov_len = sizeof(buf);
405 	msg.msg_iov = &iov;
406 	msg.msg_iovlen = 1;
407 
408 	curr_hash = 0;
409 
410 	while (1) {
411 		ssize_t recv_size;
412 
413 		recv_size = recvmsg(fd, &msg, 0);
414 
415 		if (!recv_size)
416 			break;
417 
418 		if (recv_size < 0) {
419 			perror("recvmsg");
420 			exit(EXIT_FAILURE);
421 		}
422 
423 		if (msg.msg_flags & MSG_EOR)
424 			curr_hash++;
425 
426 		curr_hash += hash_djb2(msg.msg_iov[0].iov_base, recv_size);
427 	}
428 
429 	close(fd);
430 	remote_hash = control_readulong();
431 
432 	if (curr_hash != remote_hash) {
433 		fprintf(stderr, "Message bounds broken\n");
434 		exit(EXIT_FAILURE);
435 	}
436 }
437 
438 #define MESSAGE_TRUNC_SZ 32
439 static void test_seqpacket_msg_trunc_client(const struct test_opts *opts)
440 {
441 	int fd;
442 	char buf[MESSAGE_TRUNC_SZ];
443 
444 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
445 	if (fd < 0) {
446 		perror("connect");
447 		exit(EXIT_FAILURE);
448 	}
449 
450 	if (send(fd, buf, sizeof(buf), 0) != sizeof(buf)) {
451 		perror("send failed");
452 		exit(EXIT_FAILURE);
453 	}
454 
455 	control_writeln("SENDDONE");
456 	close(fd);
457 }
458 
459 static void test_seqpacket_msg_trunc_server(const struct test_opts *opts)
460 {
461 	int fd;
462 	char buf[MESSAGE_TRUNC_SZ / 2];
463 	struct msghdr msg = {0};
464 	struct iovec iov = {0};
465 
466 	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
467 	if (fd < 0) {
468 		perror("accept");
469 		exit(EXIT_FAILURE);
470 	}
471 
472 	control_expectln("SENDDONE");
473 	iov.iov_base = buf;
474 	iov.iov_len = sizeof(buf);
475 	msg.msg_iov = &iov;
476 	msg.msg_iovlen = 1;
477 
478 	ssize_t ret = recvmsg(fd, &msg, MSG_TRUNC);
479 
480 	if (ret != MESSAGE_TRUNC_SZ) {
481 		printf("%zi\n", ret);
482 		perror("MSG_TRUNC doesn't work");
483 		exit(EXIT_FAILURE);
484 	}
485 
486 	if (!(msg.msg_flags & MSG_TRUNC)) {
487 		fprintf(stderr, "MSG_TRUNC expected\n");
488 		exit(EXIT_FAILURE);
489 	}
490 
491 	close(fd);
492 }
493 
494 static time_t current_nsec(void)
495 {
496 	struct timespec ts;
497 
498 	if (clock_gettime(CLOCK_REALTIME, &ts)) {
499 		perror("clock_gettime(3) failed");
500 		exit(EXIT_FAILURE);
501 	}
502 
503 	return (ts.tv_sec * 1000000000ULL) + ts.tv_nsec;
504 }
505 
506 #define RCVTIMEO_TIMEOUT_SEC 1
507 #define READ_OVERHEAD_NSEC 250000000 /* 0.25 sec */
508 
509 static void test_seqpacket_timeout_client(const struct test_opts *opts)
510 {
511 	int fd;
512 	struct timeval tv;
513 	char dummy;
514 	time_t read_enter_ns;
515 	time_t read_overhead_ns;
516 
517 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
518 	if (fd < 0) {
519 		perror("connect");
520 		exit(EXIT_FAILURE);
521 	}
522 
523 	tv.tv_sec = RCVTIMEO_TIMEOUT_SEC;
524 	tv.tv_usec = 0;
525 
526 	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, (void *)&tv, sizeof(tv)) == -1) {
527 		perror("setsockopt(SO_RCVTIMEO)");
528 		exit(EXIT_FAILURE);
529 	}
530 
531 	read_enter_ns = current_nsec();
532 
533 	if (read(fd, &dummy, sizeof(dummy)) != -1) {
534 		fprintf(stderr,
535 			"expected 'dummy' read(2) failure\n");
536 		exit(EXIT_FAILURE);
537 	}
538 
539 	if (errno != EAGAIN) {
540 		perror("EAGAIN expected");
541 		exit(EXIT_FAILURE);
542 	}
543 
544 	read_overhead_ns = current_nsec() - read_enter_ns -
545 			1000000000ULL * RCVTIMEO_TIMEOUT_SEC;
546 
547 	if (read_overhead_ns > READ_OVERHEAD_NSEC) {
548 		fprintf(stderr,
549 			"too much time in read(2), %lu > %i ns\n",
550 			read_overhead_ns, READ_OVERHEAD_NSEC);
551 		exit(EXIT_FAILURE);
552 	}
553 
554 	control_writeln("WAITDONE");
555 	close(fd);
556 }
557 
558 static void test_seqpacket_timeout_server(const struct test_opts *opts)
559 {
560 	int fd;
561 
562 	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
563 	if (fd < 0) {
564 		perror("accept");
565 		exit(EXIT_FAILURE);
566 	}
567 
568 	control_expectln("WAITDONE");
569 	close(fd);
570 }
571 
572 static void test_seqpacket_bigmsg_client(const struct test_opts *opts)
573 {
574 	unsigned long sock_buf_size;
575 	ssize_t send_size;
576 	socklen_t len;
577 	void *data;
578 	int fd;
579 
580 	len = sizeof(sock_buf_size);
581 
582 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
583 	if (fd < 0) {
584 		perror("connect");
585 		exit(EXIT_FAILURE);
586 	}
587 
588 	if (getsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
589 		       &sock_buf_size, &len)) {
590 		perror("getsockopt");
591 		exit(EXIT_FAILURE);
592 	}
593 
594 	sock_buf_size++;
595 
596 	data = malloc(sock_buf_size);
597 	if (!data) {
598 		perror("malloc");
599 		exit(EXIT_FAILURE);
600 	}
601 
602 	send_size = send(fd, data, sock_buf_size, 0);
603 	if (send_size != -1) {
604 		fprintf(stderr, "expected 'send(2)' failure, got %zi\n",
605 			send_size);
606 		exit(EXIT_FAILURE);
607 	}
608 
609 	if (errno != EMSGSIZE) {
610 		fprintf(stderr, "expected EMSGSIZE in 'errno', got %i\n",
611 			errno);
612 		exit(EXIT_FAILURE);
613 	}
614 
615 	control_writeln("CLISENT");
616 
617 	free(data);
618 	close(fd);
619 }
620 
621 static void test_seqpacket_bigmsg_server(const struct test_opts *opts)
622 {
623 	int fd;
624 
625 	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
626 	if (fd < 0) {
627 		perror("accept");
628 		exit(EXIT_FAILURE);
629 	}
630 
631 	control_expectln("CLISENT");
632 
633 	close(fd);
634 }
635 
636 #define BUF_PATTERN_1 'a'
637 #define BUF_PATTERN_2 'b'
638 
639 static void test_seqpacket_invalid_rec_buffer_client(const struct test_opts *opts)
640 {
641 	int fd;
642 	unsigned char *buf1;
643 	unsigned char *buf2;
644 	int buf_size = getpagesize() * 3;
645 
646 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
647 	if (fd < 0) {
648 		perror("connect");
649 		exit(EXIT_FAILURE);
650 	}
651 
652 	buf1 = malloc(buf_size);
653 	if (!buf1) {
654 		perror("'malloc()' for 'buf1'");
655 		exit(EXIT_FAILURE);
656 	}
657 
658 	buf2 = malloc(buf_size);
659 	if (!buf2) {
660 		perror("'malloc()' for 'buf2'");
661 		exit(EXIT_FAILURE);
662 	}
663 
664 	memset(buf1, BUF_PATTERN_1, buf_size);
665 	memset(buf2, BUF_PATTERN_2, buf_size);
666 
667 	if (send(fd, buf1, buf_size, 0) != buf_size) {
668 		perror("send failed");
669 		exit(EXIT_FAILURE);
670 	}
671 
672 	if (send(fd, buf2, buf_size, 0) != buf_size) {
673 		perror("send failed");
674 		exit(EXIT_FAILURE);
675 	}
676 
677 	close(fd);
678 }
679 
680 static void test_seqpacket_invalid_rec_buffer_server(const struct test_opts *opts)
681 {
682 	int fd;
683 	unsigned char *broken_buf;
684 	unsigned char *valid_buf;
685 	int page_size = getpagesize();
686 	int buf_size = page_size * 3;
687 	ssize_t res;
688 	int prot = PROT_READ | PROT_WRITE;
689 	int flags = MAP_PRIVATE | MAP_ANONYMOUS;
690 	int i;
691 
692 	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
693 	if (fd < 0) {
694 		perror("accept");
695 		exit(EXIT_FAILURE);
696 	}
697 
698 	/* Setup first buffer. */
699 	broken_buf = mmap(NULL, buf_size, prot, flags, -1, 0);
700 	if (broken_buf == MAP_FAILED) {
701 		perror("mmap for 'broken_buf'");
702 		exit(EXIT_FAILURE);
703 	}
704 
705 	/* Unmap "hole" in buffer. */
706 	if (munmap(broken_buf + page_size, page_size)) {
707 		perror("'broken_buf' setup");
708 		exit(EXIT_FAILURE);
709 	}
710 
711 	valid_buf = mmap(NULL, buf_size, prot, flags, -1, 0);
712 	if (valid_buf == MAP_FAILED) {
713 		perror("mmap for 'valid_buf'");
714 		exit(EXIT_FAILURE);
715 	}
716 
717 	/* Try to fill buffer with unmapped middle. */
718 	res = read(fd, broken_buf, buf_size);
719 	if (res != -1) {
720 		fprintf(stderr,
721 			"expected 'broken_buf' read(2) failure, got %zi\n",
722 			res);
723 		exit(EXIT_FAILURE);
724 	}
725 
726 	if (errno != ENOMEM) {
727 		perror("unexpected errno of 'broken_buf'");
728 		exit(EXIT_FAILURE);
729 	}
730 
731 	/* Try to fill valid buffer. */
732 	res = read(fd, valid_buf, buf_size);
733 	if (res < 0) {
734 		perror("unexpected 'valid_buf' read(2) failure");
735 		exit(EXIT_FAILURE);
736 	}
737 
738 	if (res != buf_size) {
739 		fprintf(stderr,
740 			"invalid 'valid_buf' read(2), expected %i, got %zi\n",
741 			buf_size, res);
742 		exit(EXIT_FAILURE);
743 	}
744 
745 	for (i = 0; i < buf_size; i++) {
746 		if (valid_buf[i] != BUF_PATTERN_2) {
747 			fprintf(stderr,
748 				"invalid pattern for 'valid_buf' at %i, expected %hhX, got %hhX\n",
749 				i, BUF_PATTERN_2, valid_buf[i]);
750 			exit(EXIT_FAILURE);
751 		}
752 	}
753 
754 	/* Unmap buffers. */
755 	munmap(broken_buf, page_size);
756 	munmap(broken_buf + page_size * 2, page_size);
757 	munmap(valid_buf, buf_size);
758 	close(fd);
759 }
760 
761 #define RCVLOWAT_BUF_SIZE 128
762 
763 static void test_stream_poll_rcvlowat_server(const struct test_opts *opts)
764 {
765 	int fd;
766 	int i;
767 
768 	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
769 	if (fd < 0) {
770 		perror("accept");
771 		exit(EXIT_FAILURE);
772 	}
773 
774 	/* Send 1 byte. */
775 	send_byte(fd, 1, 0);
776 
777 	control_writeln("SRVSENT");
778 
779 	/* Wait until client is ready to receive rest of data. */
780 	control_expectln("CLNSENT");
781 
782 	for (i = 0; i < RCVLOWAT_BUF_SIZE - 1; i++)
783 		send_byte(fd, 1, 0);
784 
785 	/* Keep socket in active state. */
786 	control_expectln("POLLDONE");
787 
788 	close(fd);
789 }
790 
791 static void test_stream_poll_rcvlowat_client(const struct test_opts *opts)
792 {
793 	unsigned long lowat_val = RCVLOWAT_BUF_SIZE;
794 	char buf[RCVLOWAT_BUF_SIZE];
795 	struct pollfd fds;
796 	ssize_t read_res;
797 	short poll_flags;
798 	int fd;
799 
800 	fd = vsock_stream_connect(opts->peer_cid, 1234);
801 	if (fd < 0) {
802 		perror("connect");
803 		exit(EXIT_FAILURE);
804 	}
805 
806 	if (setsockopt(fd, SOL_SOCKET, SO_RCVLOWAT,
807 		       &lowat_val, sizeof(lowat_val))) {
808 		perror("setsockopt(SO_RCVLOWAT)");
809 		exit(EXIT_FAILURE);
810 	}
811 
812 	control_expectln("SRVSENT");
813 
814 	/* At this point, server sent 1 byte. */
815 	fds.fd = fd;
816 	poll_flags = POLLIN | POLLRDNORM;
817 	fds.events = poll_flags;
818 
819 	/* Try to wait for 1 sec. */
820 	if (poll(&fds, 1, 1000) < 0) {
821 		perror("poll");
822 		exit(EXIT_FAILURE);
823 	}
824 
825 	/* poll() must return nothing. */
826 	if (fds.revents) {
827 		fprintf(stderr, "Unexpected poll result %hx\n",
828 			fds.revents);
829 		exit(EXIT_FAILURE);
830 	}
831 
832 	/* Tell server to send rest of data. */
833 	control_writeln("CLNSENT");
834 
835 	/* Poll for data. */
836 	if (poll(&fds, 1, 10000) < 0) {
837 		perror("poll");
838 		exit(EXIT_FAILURE);
839 	}
840 
841 	/* Only these two bits are expected. */
842 	if (fds.revents != poll_flags) {
843 		fprintf(stderr, "Unexpected poll result %hx\n",
844 			fds.revents);
845 		exit(EXIT_FAILURE);
846 	}
847 
848 	/* Use MSG_DONTWAIT, if call is going to wait, EAGAIN
849 	 * will be returned.
850 	 */
851 	read_res = recv(fd, buf, sizeof(buf), MSG_DONTWAIT);
852 	if (read_res != RCVLOWAT_BUF_SIZE) {
853 		fprintf(stderr, "Unexpected recv result %zi\n",
854 			read_res);
855 		exit(EXIT_FAILURE);
856 	}
857 
858 	control_writeln("POLLDONE");
859 
860 	close(fd);
861 }
862 
863 static struct test_case test_cases[] = {
864 	{
865 		.name = "SOCK_STREAM connection reset",
866 		.run_client = test_stream_connection_reset,
867 	},
868 	{
869 		.name = "SOCK_STREAM bind only",
870 		.run_client = test_stream_bind_only_client,
871 		.run_server = test_stream_bind_only_server,
872 	},
873 	{
874 		.name = "SOCK_STREAM client close",
875 		.run_client = test_stream_client_close_client,
876 		.run_server = test_stream_client_close_server,
877 	},
878 	{
879 		.name = "SOCK_STREAM server close",
880 		.run_client = test_stream_server_close_client,
881 		.run_server = test_stream_server_close_server,
882 	},
883 	{
884 		.name = "SOCK_STREAM multiple connections",
885 		.run_client = test_stream_multiconn_client,
886 		.run_server = test_stream_multiconn_server,
887 	},
888 	{
889 		.name = "SOCK_STREAM MSG_PEEK",
890 		.run_client = test_stream_msg_peek_client,
891 		.run_server = test_stream_msg_peek_server,
892 	},
893 	{
894 		.name = "SOCK_SEQPACKET msg bounds",
895 		.run_client = test_seqpacket_msg_bounds_client,
896 		.run_server = test_seqpacket_msg_bounds_server,
897 	},
898 	{
899 		.name = "SOCK_SEQPACKET MSG_TRUNC flag",
900 		.run_client = test_seqpacket_msg_trunc_client,
901 		.run_server = test_seqpacket_msg_trunc_server,
902 	},
903 	{
904 		.name = "SOCK_SEQPACKET timeout",
905 		.run_client = test_seqpacket_timeout_client,
906 		.run_server = test_seqpacket_timeout_server,
907 	},
908 	{
909 		.name = "SOCK_SEQPACKET invalid receive buffer",
910 		.run_client = test_seqpacket_invalid_rec_buffer_client,
911 		.run_server = test_seqpacket_invalid_rec_buffer_server,
912 	},
913 	{
914 		.name = "SOCK_STREAM poll() + SO_RCVLOWAT",
915 		.run_client = test_stream_poll_rcvlowat_client,
916 		.run_server = test_stream_poll_rcvlowat_server,
917 	},
918 	{
919 		.name = "SOCK_SEQPACKET big message",
920 		.run_client = test_seqpacket_bigmsg_client,
921 		.run_server = test_seqpacket_bigmsg_server,
922 	},
923 	{},
924 };
925 
/* No short options are accepted; everything goes through longopts. */
static const char optstring[] = "";
/* Long options: { name, has_arg, flag, val }. The value is what
 * getopt_long() returns for the option; the table ends with a zero entry.
 */
static const struct option longopts[] = {
	{ "control-host", required_argument, NULL, 'H' },
	{ "control-port", required_argument, NULL, 'P' },
	{ "mode",         required_argument, NULL, 'm' },
	{ "peer-cid",     required_argument, NULL, 'p' },
	{ "list",         no_argument,       NULL, 'l' },
	{ "skip",         required_argument, NULL, 's' },
	{ "help",         no_argument,       NULL, '?' },
	{ NULL,           0,                 NULL, 0 },
};
965 
/* Print the command-line help to stderr and exit with EXIT_FAILURE. */
static void usage(void)
{
	static const char help_text[] =
		"Usage: vsock_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--list] [--skip=<test_id>]\n"
		"\n"
		"  Server: vsock_test --control-port=1234 --mode=server --peer-cid=3\n"
		"  Client: vsock_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
		"\n"
		"Run vsock.ko tests.  Must be launched in both guest\n"
		"and host.  One side must use --mode=client and\n"
		"the other side must use --mode=server.\n"
		"\n"
		"A TCP control socket connection is used to coordinate tests\n"
		"between the client and the server.  The server requires a\n"
		"listen address and the client requires an address to\n"
		"connect to.\n"
		"\n"
		"The CID of the other side must be given with --peer-cid=<cid>.\n"
		"\n"
		"Options:\n"
		"  --help                 This help message\n"
		"  --control-host <host>  Server IP address to connect to\n"
		"  --control-port <port>  Server port to listen on/connect to\n"
		"  --mode client|server   Server or client mode\n"
		"  --peer-cid <cid>       CID of the other side\n"
		"  --list                 List of tests that will be executed\n"
		"  --skip <test_id>       Test ID to skip;\n"
		"                         use multiple --skip options to skip more tests\n";

	/* The text has no conversion specifiers, so fputs prints it verbatim. */
	fputs(help_text, stderr);
	exit(EXIT_FAILURE);
}
996 
997 int main(int argc, char **argv)
998 {
999 	const char *control_host = NULL;
1000 	const char *control_port = NULL;
1001 	struct test_opts opts = {
1002 		.mode = TEST_MODE_UNSET,
1003 		.peer_cid = VMADDR_CID_ANY,
1004 	};
1005 
1006 	srand(time(NULL));
1007 	init_signals();
1008 
1009 	for (;;) {
1010 		int opt = getopt_long(argc, argv, optstring, longopts, NULL);
1011 
1012 		if (opt == -1)
1013 			break;
1014 
1015 		switch (opt) {
1016 		case 'H':
1017 			control_host = optarg;
1018 			break;
1019 		case 'm':
1020 			if (strcmp(optarg, "client") == 0)
1021 				opts.mode = TEST_MODE_CLIENT;
1022 			else if (strcmp(optarg, "server") == 0)
1023 				opts.mode = TEST_MODE_SERVER;
1024 			else {
1025 				fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
1026 				return EXIT_FAILURE;
1027 			}
1028 			break;
1029 		case 'p':
1030 			opts.peer_cid = parse_cid(optarg);
1031 			break;
1032 		case 'P':
1033 			control_port = optarg;
1034 			break;
1035 		case 'l':
1036 			list_tests(test_cases);
1037 			break;
1038 		case 's':
1039 			skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
1040 				  optarg);
1041 			break;
1042 		case '?':
1043 		default:
1044 			usage();
1045 		}
1046 	}
1047 
1048 	if (!control_port)
1049 		usage();
1050 	if (opts.mode == TEST_MODE_UNSET)
1051 		usage();
1052 	if (opts.peer_cid == VMADDR_CID_ANY)
1053 		usage();
1054 
1055 	if (!control_host) {
1056 		if (opts.mode != TEST_MODE_SERVER)
1057 			usage();
1058 		control_host = "0.0.0.0";
1059 	}
1060 
1061 	control_init(control_host, control_port,
1062 		     opts.mode == TEST_MODE_SERVER);
1063 
1064 	run_tests(test_cases, &opts);
1065 
1066 	control_cleanup();
1067 	return EXIT_SUCCESS;
1068 }
1069