xref: /linux/tools/testing/vsock/vsock_diag_test.c (revision eb01fe7abbe2d0b38824d2a93fdb4cc3eaf2ccc1)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * vsock_diag_test - vsock_diag.ko test suite
4  *
5  * Copyright (C) 2017 Red Hat, Inc.
6  *
7  * Author: Stefan Hajnoczi <stefanha@redhat.com>
8  */
9 
10 #include <getopt.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <string.h>
14 #include <errno.h>
15 #include <unistd.h>
16 #include <sys/stat.h>
17 #include <sys/types.h>
18 #include <linux/list.h>
19 #include <linux/net.h>
20 #include <linux/netlink.h>
21 #include <linux/sock_diag.h>
22 #include <linux/vm_sockets_diag.h>
23 #include <netinet/tcp.h>
24 
25 #include "timeout.h"
26 #include "control.h"
27 #include "util.h"
28 
/*
 * Per-socket status: one node per vsock_diag_msg received from the
 * SOCK_DIAG_BY_FAMILY dump (see read_vsock_stat()/add_vsock_stat()).
 */
struct vsock_stat {
	struct list_head list;		/* link in the caller's list of sockets */
	struct vsock_diag_msg msg;	/* copy of the netlink reply for this socket */
};
34 
/*
 * Map a socket type (SOCK_DGRAM, SOCK_STREAM, SOCK_SEQPACKET) to a short
 * printable name; any other value yields "INVALID TYPE".
 */
static const char *sock_type_str(int type)
{
	if (type == SOCK_DGRAM)
		return "DGRAM";
	if (type == SOCK_STREAM)
		return "STREAM";
	if (type == SOCK_SEQPACKET)
		return "SEQPACKET";

	return "INVALID TYPE";
}
48 
49 static const char *sock_state_str(int state)
50 {
51 	switch (state) {
52 	case TCP_CLOSE:
53 		return "UNCONNECTED";
54 	case TCP_SYN_SENT:
55 		return "CONNECTING";
56 	case TCP_ESTABLISHED:
57 		return "CONNECTED";
58 	case TCP_CLOSING:
59 		return "DISCONNECTING";
60 	case TCP_LISTEN:
61 		return "LISTEN";
62 	default:
63 		return "INVALID STATE";
64 	}
65 }
66 
/*
 * Describe a vdiag_shutdown bitmask: bit value 1 is reported as
 * RCV_SHUTDOWN, 2 as SEND_SHUTDOWN, and 3 as both.  Anything else
 * (including 0) prints as "0".
 */
static const char *sock_shutdown_str(int shutdown)
{
	enum { RCV_BIT = 1, SEND_BIT = 2 };

	if (shutdown == (RCV_BIT | SEND_BIT))
		return "RCV_SHUTDOWN | SEND_SHUTDOWN";
	if (shutdown == SEND_BIT)
		return "SEND_SHUTDOWN";
	if (shutdown == RCV_BIT)
		return "RCV_SHUTDOWN";

	return "0";
}
80 
/*
 * Write "<cid>:<port>" to @fp.  The wildcard values VMADDR_CID_ANY and
 * VMADDR_PORT_ANY are shown as '*'; everything else as unsigned decimal.
 * No trailing separator or newline is written.
 */
static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
{
	if (cid != VMADDR_CID_ANY)
		fprintf(fp, "%u:", cid);
	else
		fputs("*:", fp);

	if (port != VMADDR_PORT_ANY)
		fprintf(fp, "%u", port);
	else
		fputs("*", fp);
}
93 
94 static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
95 {
96 	print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
97 	fprintf(fp, " ");
98 	print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
99 	fprintf(fp, " %s %s %s %u\n",
100 		sock_type_str(st->msg.vdiag_type),
101 		sock_state_str(st->msg.vdiag_state),
102 		sock_shutdown_str(st->msg.vdiag_shutdown),
103 		st->msg.vdiag_ino);
104 }
105 
106 static void print_vsock_stats(FILE *fp, struct list_head *head)
107 {
108 	struct vsock_stat *st;
109 
110 	list_for_each_entry(st, head, list)
111 		print_vsock_stat(fp, st);
112 }
113 
114 static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
115 {
116 	struct vsock_stat *st;
117 	struct stat stat;
118 
119 	if (fstat(fd, &stat) < 0) {
120 		perror("fstat");
121 		exit(EXIT_FAILURE);
122 	}
123 
124 	list_for_each_entry(st, head, list)
125 		if (st->msg.vdiag_ino == stat.st_ino)
126 			return st;
127 
128 	fprintf(stderr, "cannot find fd %d\n", fd);
129 	exit(EXIT_FAILURE);
130 }
131 
/*
 * Fail the test unless @head is empty.  On failure, print the unexpected
 * sockets to stderr before exiting.
 */
static void check_no_sockets(struct list_head *head)
{
	if (!list_empty(head)) {
		fprintf(stderr, "expected no sockets\n");
		print_vsock_stats(stderr, head);
		/* Same exit status as every other failure path in this file. */
		exit(EXIT_FAILURE);
	}
}
140 
141 static void check_num_sockets(struct list_head *head, int expected)
142 {
143 	struct list_head *node;
144 	int n = 0;
145 
146 	list_for_each(node, head)
147 		n++;
148 
149 	if (n != expected) {
150 		fprintf(stderr, "expected %d sockets, found %d\n",
151 			expected, n);
152 		print_vsock_stats(stderr, head);
153 		exit(EXIT_FAILURE);
154 	}
155 }
156 
157 static void check_socket_state(struct vsock_stat *st, __u8 state)
158 {
159 	if (st->msg.vdiag_state != state) {
160 		fprintf(stderr, "expected socket state %#x, got %#x\n",
161 			state, st->msg.vdiag_state);
162 		exit(EXIT_FAILURE);
163 	}
164 }
165 
/*
 * Send a single SOCK_DIAG_BY_FAMILY dump request for AF_VSOCK sockets in
 * every state (all vdiag_states bits set) on netlink socket @fd, addressed
 * to nl_pid 0.  Interrupted sends are retried; any other sendmsg() failure
 * exits the process.
 */
static void send_req(int fd)
{
	struct sockaddr_nl kernel = {
		.nl_family = AF_NETLINK,
	};
	struct {
		struct nlmsghdr nlh;
		struct vsock_diag_req vreq;
	} req = {
		.nlh = {
			.nlmsg_len = sizeof(req),
			.nlmsg_type = SOCK_DIAG_BY_FAMILY,
			.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
		},
		.vreq = {
			.sdiag_family = AF_VSOCK,
			.vdiag_states = ~(__u32)0,
		},
	};
	struct iovec vec = {
		.iov_base = &req,
		.iov_len = sizeof(req),
	};
	struct msghdr mh = {
		.msg_name = &kernel,
		.msg_namelen = sizeof(kernel),
		.msg_iov = &vec,
		.msg_iovlen = 1,
	};

	while (sendmsg(fd, &mh, 0) < 0) {
		if (errno != EINTR) {
			perror("sendmsg");
			exit(EXIT_FAILURE);
		}
	}
}
208 
/*
 * Receive one netlink datagram from @fd into @buf (at most @len bytes).
 * Interrupted receives are retried.  Returns the byte count from
 * recvmsg(), which is never negative: any other error exits the process.
 */
static ssize_t recv_resp(int fd, void *buf, size_t len)
{
	struct sockaddr_nl src = { .nl_family = AF_NETLINK };
	struct iovec vec = { .iov_base = buf, .iov_len = len };
	struct msghdr mh = {
		.msg_name = &src,
		.msg_namelen = sizeof(src),
		.msg_iov = &vec,
		.msg_iovlen = 1,
	};
	ssize_t n;

	for (;;) {
		n = recvmsg(fd, &mh, 0);
		if (n >= 0)
			return n;
		if (errno != EINTR)
			break;
	}

	perror("recvmsg");
	exit(EXIT_FAILURE);
}
237 
238 static void add_vsock_stat(struct list_head *sockets,
239 			   const struct vsock_diag_msg *resp)
240 {
241 	struct vsock_stat *st;
242 
243 	st = malloc(sizeof(*st));
244 	if (!st) {
245 		perror("malloc");
246 		exit(EXIT_FAILURE);
247 	}
248 
249 	st->msg = *resp;
250 	list_add_tail(&st->list, sockets);
251 }
252 
/*
 * Read vsock stats into a list.
 *
 * Opens a NETLINK_SOCK_DIAG socket, sends one AF_VSOCK dump request
 * (send_req()) and appends one struct vsock_stat per SOCK_DIAG_BY_FAMILY
 * reply to @sockets, in the order received.  Reading stops at NLMSG_DONE
 * or when recv_resp() returns 0.  A netlink error, an unexpected message
 * type or a truncated message terminates the process.
 */
static void read_vsock_stat(struct list_head *sockets)
{
	/*
	 * long[] rather than char[], presumably so the nlmsghdr casts below
	 * are suitably aligned.
	 */
	long buf[8192 / sizeof(long)];
	int fd;

	fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
	if (fd < 0) {
		perror("socket");
		exit(EXIT_FAILURE);
	}

	send_req(fd);

	for (;;) {
		const struct nlmsghdr *h;
		ssize_t ret;

		ret = recv_resp(fd, buf, sizeof(buf));
		if (ret == 0)
			goto done;
		/*
		 * recv_resp() exits on error, so ret > 0 here and the cast is
		 * lossless; it keeps the comparison unsigned vs. unsigned.
		 */
		if ((size_t)ret < sizeof(*h)) {
			fprintf(stderr, "short read of %zd bytes\n", ret);
			exit(EXIT_FAILURE);
		}

		h = (const struct nlmsghdr *)buf;

		/* NLMSG_NEXT() subtracts each consumed message from ret. */
		while (NLMSG_OK(h, ret)) {
			if (h->nlmsg_type == NLMSG_DONE)
				goto done;

			if (h->nlmsg_type == NLMSG_ERROR) {
				const struct nlmsgerr *err = NLMSG_DATA(h);

				if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
					fprintf(stderr, "NLMSG_ERROR\n");
				else {
					/* nlmsgerr.error carries a negative errno */
					errno = -err->error;
					perror("NLMSG_ERROR");
				}

				exit(EXIT_FAILURE);
			}

			if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
				fprintf(stderr, "unexpected nlmsg_type %#x\n",
					h->nlmsg_type);
				exit(EXIT_FAILURE);
			}
			if (h->nlmsg_len <
			    NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
				fprintf(stderr, "short vsock_diag_msg\n");
				exit(EXIT_FAILURE);
			}

			add_vsock_stat(sockets, NLMSG_DATA(h));

			h = NLMSG_NEXT(h, ret);
		}
	}

done:
	close(fd);
}
320 
321 static void free_sock_stat(struct list_head *sockets)
322 {
323 	struct vsock_stat *st;
324 	struct vsock_stat *next;
325 
326 	list_for_each_entry_safe(st, next, sockets, list)
327 		free(st);
328 }
329 
330 static void test_no_sockets(const struct test_opts *opts)
331 {
332 	LIST_HEAD(sockets);
333 
334 	read_vsock_stat(&sockets);
335 
336 	check_no_sockets(&sockets);
337 }
338 
339 static void test_listen_socket_server(const struct test_opts *opts)
340 {
341 	union {
342 		struct sockaddr sa;
343 		struct sockaddr_vm svm;
344 	} addr = {
345 		.svm = {
346 			.svm_family = AF_VSOCK,
347 			.svm_port = opts->peer_port,
348 			.svm_cid = VMADDR_CID_ANY,
349 		},
350 	};
351 	LIST_HEAD(sockets);
352 	struct vsock_stat *st;
353 	int fd;
354 
355 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
356 
357 	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
358 		perror("bind");
359 		exit(EXIT_FAILURE);
360 	}
361 
362 	if (listen(fd, 1) < 0) {
363 		perror("listen");
364 		exit(EXIT_FAILURE);
365 	}
366 
367 	read_vsock_stat(&sockets);
368 
369 	check_num_sockets(&sockets, 1);
370 	st = find_vsock_stat(&sockets, fd);
371 	check_socket_state(st, TCP_LISTEN);
372 
373 	close(fd);
374 	free_sock_stat(&sockets);
375 }
376 
377 static void test_connect_client(const struct test_opts *opts)
378 {
379 	int fd;
380 	LIST_HEAD(sockets);
381 	struct vsock_stat *st;
382 
383 	fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
384 	if (fd < 0) {
385 		perror("connect");
386 		exit(EXIT_FAILURE);
387 	}
388 
389 	read_vsock_stat(&sockets);
390 
391 	check_num_sockets(&sockets, 1);
392 	st = find_vsock_stat(&sockets, fd);
393 	check_socket_state(st, TCP_ESTABLISHED);
394 
395 	control_expectln("DONE");
396 	control_writeln("DONE");
397 
398 	close(fd);
399 	free_sock_stat(&sockets);
400 }
401 
402 static void test_connect_server(const struct test_opts *opts)
403 {
404 	struct vsock_stat *st;
405 	LIST_HEAD(sockets);
406 	int client_fd;
407 
408 	client_fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
409 	if (client_fd < 0) {
410 		perror("accept");
411 		exit(EXIT_FAILURE);
412 	}
413 
414 	read_vsock_stat(&sockets);
415 
416 	check_num_sockets(&sockets, 1);
417 	st = find_vsock_stat(&sockets, client_fd);
418 	check_socket_state(st, TCP_ESTABLISHED);
419 
420 	control_writeln("DONE");
421 	control_expectln("DONE");
422 
423 	close(client_fd);
424 	free_sock_stat(&sockets);
425 }
426 
427 static struct test_case test_cases[] = {
428 	{
429 		.name = "No sockets",
430 		.run_server = test_no_sockets,
431 	},
432 	{
433 		.name = "Listen socket",
434 		.run_server = test_listen_socket_server,
435 	},
436 	{
437 		.name = "Connect",
438 		.run_client = test_connect_client,
439 		.run_server = test_connect_server,
440 	},
441 	{},
442 };
443 
/* No short options are accepted; everything is a --long-option. */
static const char optstring[] = "";

/*
 * Long options parsed by main().  .flag is left NULL, so getopt_long()
 * returns each entry's .val, which main() switches on.
 */
static const struct option longopts[] = {
	{ .name = "control-host", .has_arg = required_argument, .val = 'H' },
	{ .name = "control-port", .has_arg = required_argument, .val = 'P' },
	{ .name = "mode",         .has_arg = required_argument, .val = 'm' },
	{ .name = "peer-cid",     .has_arg = required_argument, .val = 'p' },
	{ .name = "peer-port",    .has_arg = required_argument, .val = 'q' },
	{ .name = "list",         .has_arg = no_argument,       .val = 'l' },
	{ .name = "skip",         .has_arg = required_argument, .val = 's' },
	{ .name = "help",         .has_arg = no_argument,       .val = '?' },
	{},
};
488 
/*
 * Print the command-line help text to stderr and exit with EXIT_FAILURE.
 * main() calls this for --help (whose getopt value is '?'), for unknown
 * options and when a mandatory option is missing, so it never returns.
 */
static void usage(void)
{
	fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--peer-port=<port>] [--list] [--skip=<test_id>]\n"
		"\n"
		"  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
		"  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
		"\n"
		"Run vsock_diag.ko tests.  Must be launched in both\n"
		"guest and host.  One side must use --mode=client and\n"
		"the other side must use --mode=server.\n"
		"\n"
		"A TCP control socket connection is used to coordinate tests\n"
		"between the client and the server.  The server requires a\n"
		"listen address and the client requires an address to\n"
		"connect to.\n"
		"\n"
		"The CID of the other side must be given with --peer-cid=<cid>.\n"
		"\n"
		"Options:\n"
		"  --help                 This help message\n"
		"  --control-host <host>  Server IP address to connect to\n"
		"  --control-port <port>  Server port to listen on/connect to\n"
		"  --mode client|server   Server or client mode\n"
		"  --peer-cid <cid>       CID of the other side\n"
		"  --peer-port <port>     AF_VSOCK port used for the test [default: %d]\n"
		"  --list                 List of tests that will be executed\n"
		"  --skip <test_id>       Test ID to skip;\n"
		"                         use multiple --skip options to skip more tests\n",
		DEFAULT_PEER_PORT /* fills the %d in the --peer-port line */
		);
	exit(EXIT_FAILURE);
}
521 
522 int main(int argc, char **argv)
523 {
524 	const char *control_host = NULL;
525 	const char *control_port = NULL;
526 	struct test_opts opts = {
527 		.mode = TEST_MODE_UNSET,
528 		.peer_cid = VMADDR_CID_ANY,
529 		.peer_port = DEFAULT_PEER_PORT,
530 	};
531 
532 	init_signals();
533 
534 	for (;;) {
535 		int opt = getopt_long(argc, argv, optstring, longopts, NULL);
536 
537 		if (opt == -1)
538 			break;
539 
540 		switch (opt) {
541 		case 'H':
542 			control_host = optarg;
543 			break;
544 		case 'm':
545 			if (strcmp(optarg, "client") == 0)
546 				opts.mode = TEST_MODE_CLIENT;
547 			else if (strcmp(optarg, "server") == 0)
548 				opts.mode = TEST_MODE_SERVER;
549 			else {
550 				fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
551 				return EXIT_FAILURE;
552 			}
553 			break;
554 		case 'p':
555 			opts.peer_cid = parse_cid(optarg);
556 			break;
557 		case 'q':
558 			opts.peer_port = parse_port(optarg);
559 			break;
560 		case 'P':
561 			control_port = optarg;
562 			break;
563 		case 'l':
564 			list_tests(test_cases);
565 			break;
566 		case 's':
567 			skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
568 				  optarg);
569 			break;
570 		case '?':
571 		default:
572 			usage();
573 		}
574 	}
575 
576 	if (!control_port)
577 		usage();
578 	if (opts.mode == TEST_MODE_UNSET)
579 		usage();
580 	if (opts.peer_cid == VMADDR_CID_ANY)
581 		usage();
582 
583 	if (!control_host) {
584 		if (opts.mode != TEST_MODE_SERVER)
585 			usage();
586 		control_host = "0.0.0.0";
587 	}
588 
589 	control_init(control_host, control_port,
590 		     opts.mode == TEST_MODE_SERVER);
591 
592 	run_tests(test_cases, &opts);
593 
594 	control_cleanup();
595 	return EXIT_SUCCESS;
596 }
597