1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3 * vsock_diag_test - vsock_diag.ko test suite
4 *
5 * Copyright (C) 2017 Red Hat, Inc.
6 *
7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
8 */
9
10 #include <getopt.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <string.h>
14 #include <errno.h>
15 #include <unistd.h>
16 #include <sys/stat.h>
17 #include <sys/types.h>
18 #include <linux/list.h>
19 #include <linux/net.h>
20 #include <linux/netlink.h>
21 #include <linux/sock_diag.h>
22 #include <linux/vm_sockets_diag.h>
23 #include <netinet/tcp.h>
24
25 #include "timeout.h"
26 #include "control.h"
27 #include "util.h"
28
29 /* Per-socket status */
30 struct vsock_stat {
31 struct list_head list;
32 struct vsock_diag_msg msg;
33 };
34
/*
 * Return a printable name for a socket type as reported in
 * vsock_diag_msg.vdiag_type. Unrecognised values yield "INVALID TYPE".
 */
static const char *sock_type_str(int type)
{
	const char *name = "INVALID TYPE";

	if (type == SOCK_DGRAM)
		name = "DGRAM";
	else if (type == SOCK_STREAM)
		name = "STREAM";
	else if (type == SOCK_SEQPACKET)
		name = "SEQPACKET";

	return name;
}
48
/*
 * Return a printable name for a socket state as reported in
 * vsock_diag_msg.vdiag_state (TCP_* constants). Any state not listed
 * here yields "INVALID STATE".
 */
static const char *sock_state_str(int state)
{
	if (state == TCP_CLOSE)
		return "UNCONNECTED";
	if (state == TCP_SYN_SENT)
		return "CONNECTING";
	if (state == TCP_ESTABLISHED)
		return "CONNECTED";
	if (state == TCP_CLOSING)
		return "DISCONNECTING";
	if (state == TCP_LISTEN)
		return "LISTEN";

	return "INVALID STATE";
}
66
/*
 * Return a printable form of the vdiag_shutdown bitmask: bit 0 is
 * printed as RCV_SHUTDOWN, bit 1 as SEND_SHUTDOWN. Values outside 1..3
 * (including 0) are printed as "0".
 */
static const char *sock_shutdown_str(int shutdown)
{
	static const char *const names[] = {
		[1] = "RCV_SHUTDOWN",
		[2] = "SEND_SHUTDOWN",
		[3] = "RCV_SHUTDOWN | SEND_SHUTDOWN",
	};

	if (shutdown < 1 || shutdown > 3)
		return "0";

	return names[shutdown];
}
80
/*
 * Print a vsock address as "<cid>:<port>" to @fp, with no trailing
 * newline. VMADDR_CID_ANY and VMADDR_PORT_ANY are printed as "*".
 */
static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
{
	if (cid != VMADDR_CID_ANY)
		fprintf(fp, "%u:", cid);
	else
		fputs("*:", fp);

	if (port != VMADDR_PORT_ANY)
		fprintf(fp, "%u", port);
	else
		fputs("*", fp);
}
93
print_vsock_stat(FILE * fp,struct vsock_stat * st)94 static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
95 {
96 print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
97 fprintf(fp, " ");
98 print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
99 fprintf(fp, " %s %s %s %u\n",
100 sock_type_str(st->msg.vdiag_type),
101 sock_state_str(st->msg.vdiag_state),
102 sock_shutdown_str(st->msg.vdiag_shutdown),
103 st->msg.vdiag_ino);
104 }
105
print_vsock_stats(FILE * fp,struct list_head * head)106 static void print_vsock_stats(FILE *fp, struct list_head *head)
107 {
108 struct vsock_stat *st;
109
110 list_for_each_entry(st, head, list)
111 print_vsock_stat(fp, st);
112 }
113
find_vsock_stat(struct list_head * head,int fd)114 static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
115 {
116 struct vsock_stat *st;
117 struct stat stat;
118
119 if (fstat(fd, &stat) < 0) {
120 perror("fstat");
121 exit(EXIT_FAILURE);
122 }
123
124 list_for_each_entry(st, head, list)
125 if (st->msg.vdiag_ino == stat.st_ino)
126 return st;
127
128 fprintf(stderr, "cannot find fd %d\n", fd);
129 exit(EXIT_FAILURE);
130 }
131
/*
 * Exit with failure (after dumping the unexpected entries to stderr)
 * unless @head is empty.
 */
static void check_no_sockets(struct list_head *head)
{
	if (list_empty(head))
		return;

	fprintf(stderr, "expected no sockets\n");
	print_vsock_stats(stderr, head);
	/* Use the same exit status as every other checker in this file. */
	exit(EXIT_FAILURE);
}
140
check_num_sockets(struct list_head * head,int expected)141 static void check_num_sockets(struct list_head *head, int expected)
142 {
143 struct list_head *node;
144 int n = 0;
145
146 list_for_each(node, head)
147 n++;
148
149 if (n != expected) {
150 fprintf(stderr, "expected %d sockets, found %d\n",
151 expected, n);
152 print_vsock_stats(stderr, head);
153 exit(EXIT_FAILURE);
154 }
155 }
156
check_socket_state(struct vsock_stat * st,__u8 state)157 static void check_socket_state(struct vsock_stat *st, __u8 state)
158 {
159 if (st->msg.vdiag_state != state) {
160 fprintf(stderr, "expected socket state %#x, got %#x\n",
161 state, st->msg.vdiag_state);
162 exit(EXIT_FAILURE);
163 }
164 }
165
/*
 * Send a SOCK_DIAG_BY_FAMILY dump request for all AF_VSOCK sockets
 * (every state bit set) on the netlink socket @fd. The send is retried
 * on EINTR; any other failure exits the process.
 */
static void send_req(int fd)
{
	struct sockaddr_nl kernel_addr = {
		.nl_family = AF_NETLINK,
	};
	struct {
		struct nlmsghdr nlh;
		struct vsock_diag_req vreq;
	} req = {
		.nlh = {
			.nlmsg_len = sizeof(req),
			.nlmsg_type = SOCK_DIAG_BY_FAMILY,
			.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
		},
		.vreq = {
			.sdiag_family = AF_VSOCK,
			.vdiag_states = ~(__u32)0,
		},
	};
	struct iovec iov = {
		.iov_base = &req,
		.iov_len = sizeof(req),
	};
	struct msghdr msg = {
		.msg_name = &kernel_addr,
		.msg_namelen = sizeof(kernel_addr),
		.msg_iov = &iov,
		.msg_iovlen = 1,
	};
	ssize_t ret;

	do {
		ret = sendmsg(fd, &msg, 0);
	} while (ret < 0 && errno == EINTR);

	if (ret < 0) {
		perror("sendmsg");
		exit(EXIT_FAILURE);
	}
}
208
/*
 * Receive one datagram from @fd into @buf (at most @len bytes), retrying
 * on EINTR. Returns the number of bytes received.
 *
 * Exits the process if recvmsg() fails, or if the kernel reports
 * MSG_TRUNC. A datagram larger than @len is cut short, and the caller's
 * NLMSG_OK() walk would then silently stop at the damaged record and
 * lose the entries after it.
 */
static ssize_t recv_resp(int fd, void *buf, size_t len)
{
	struct sockaddr_nl nladdr = {
		.nl_family = AF_NETLINK,
	};
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = len,
	};
	struct msghdr msg = {
		.msg_name = &nladdr,
		.msg_namelen = sizeof(nladdr),
		.msg_iov = &iov,
		.msg_iovlen = 1,
	};
	ssize_t ret;

	do {
		ret = recvmsg(fd, &msg, 0);
	} while (ret < 0 && errno == EINTR);

	if (ret < 0) {
		perror("recvmsg");
		exit(EXIT_FAILURE);
	}

	if (msg.msg_flags & MSG_TRUNC) {
		fprintf(stderr, "recvmsg: response truncated (buffer of %zu bytes)\n",
			len);
		exit(EXIT_FAILURE);
	}

	return ret;
}
237
add_vsock_stat(struct list_head * sockets,const struct vsock_diag_msg * resp)238 static void add_vsock_stat(struct list_head *sockets,
239 const struct vsock_diag_msg *resp)
240 {
241 struct vsock_stat *st;
242
243 st = malloc(sizeof(*st));
244 if (!st) {
245 perror("malloc");
246 exit(EXIT_FAILURE);
247 }
248
249 st->msg = *resp;
250 list_add_tail(&st->list, sockets);
251 }
252
/*
 * Read vsock stats into a list.
 *
 * Opens a NETLINK_SOCK_DIAG socket, sends one AF_VSOCK dump request and
 * appends one struct vsock_stat per returned vsock_diag_msg to @sockets,
 * in the order received. The walk stops at NLMSG_DONE or at a
 * zero-length read. Any socket, protocol or allocation error exits the
 * process.
 */
static void read_vsock_stat(struct list_head *sockets)
{
	/* Declared as longs so the struct nlmsghdr casts below are aligned. */
	long buf[8192 / sizeof(long)];
	int fd;

	fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
	if (fd < 0) {
		perror("socket");
		exit(EXIT_FAILURE);
	}

	send_req(fd);

	for (;;) {
		const struct nlmsghdr *h;
		ssize_t ret;

		ret = recv_resp(fd, buf, sizeof(buf));
		if (ret == 0)
			goto done;
		/*
		 * recv_resp() exits rather than return < 0, so this cast is
		 * lossless. It avoids a signed/unsigned comparison.
		 */
		if ((size_t)ret < sizeof(*h)) {
			fprintf(stderr, "short read of %zd bytes\n", ret);
			exit(EXIT_FAILURE);
		}

		h = (struct nlmsghdr *)buf;

		while (NLMSG_OK(h, ret)) {
			if (h->nlmsg_type == NLMSG_DONE)
				goto done;

			if (h->nlmsg_type == NLMSG_ERROR) {
				const struct nlmsgerr *err = NLMSG_DATA(h);

				/* Only trust err->error if the payload is complete. */
				if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err))) {
					fprintf(stderr, "NLMSG_ERROR\n");
				} else {
					errno = -err->error;
					perror("NLMSG_ERROR");
				}

				exit(EXIT_FAILURE);
			}

			if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
				fprintf(stderr, "unexpected nlmsg_type %#x\n",
					h->nlmsg_type);
				exit(EXIT_FAILURE);
			}
			if (h->nlmsg_len <
			    NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
				fprintf(stderr, "short vsock_diag_msg\n");
				exit(EXIT_FAILURE);
			}

			add_vsock_stat(sockets, NLMSG_DATA(h));

			h = NLMSG_NEXT(h, ret);
		}
	}

done:
	close(fd);
}
320
free_sock_stat(struct list_head * sockets)321 static void free_sock_stat(struct list_head *sockets)
322 {
323 struct vsock_stat *st;
324 struct vsock_stat *next;
325
326 list_for_each_entry_safe(st, next, sockets, list)
327 free(st);
328 }
329
test_no_sockets(const struct test_opts * opts)330 static void test_no_sockets(const struct test_opts *opts)
331 {
332 LIST_HEAD(sockets);
333
334 read_vsock_stat(&sockets);
335
336 check_no_sockets(&sockets);
337 }
338
test_listen_socket_server(const struct test_opts * opts)339 static void test_listen_socket_server(const struct test_opts *opts)
340 {
341 union {
342 struct sockaddr sa;
343 struct sockaddr_vm svm;
344 } addr = {
345 .svm = {
346 .svm_family = AF_VSOCK,
347 .svm_port = opts->peer_port,
348 .svm_cid = VMADDR_CID_ANY,
349 },
350 };
351 LIST_HEAD(sockets);
352 struct vsock_stat *st;
353 int fd;
354
355 fd = socket(AF_VSOCK, SOCK_STREAM, 0);
356
357 if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
358 perror("bind");
359 exit(EXIT_FAILURE);
360 }
361
362 if (listen(fd, 1) < 0) {
363 perror("listen");
364 exit(EXIT_FAILURE);
365 }
366
367 read_vsock_stat(&sockets);
368
369 check_num_sockets(&sockets, 1);
370 st = find_vsock_stat(&sockets, fd);
371 check_socket_state(st, TCP_LISTEN);
372
373 close(fd);
374 free_sock_stat(&sockets);
375 }
376
test_connect_client(const struct test_opts * opts)377 static void test_connect_client(const struct test_opts *opts)
378 {
379 int fd;
380 LIST_HEAD(sockets);
381 struct vsock_stat *st;
382
383 fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
384 if (fd < 0) {
385 perror("connect");
386 exit(EXIT_FAILURE);
387 }
388
389 read_vsock_stat(&sockets);
390
391 check_num_sockets(&sockets, 1);
392 st = find_vsock_stat(&sockets, fd);
393 check_socket_state(st, TCP_ESTABLISHED);
394
395 control_expectln("DONE");
396 control_writeln("DONE");
397
398 close(fd);
399 free_sock_stat(&sockets);
400 }
401
test_connect_server(const struct test_opts * opts)402 static void test_connect_server(const struct test_opts *opts)
403 {
404 struct vsock_stat *st;
405 LIST_HEAD(sockets);
406 int client_fd;
407
408 client_fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
409 if (client_fd < 0) {
410 perror("accept");
411 exit(EXIT_FAILURE);
412 }
413
414 read_vsock_stat(&sockets);
415
416 check_num_sockets(&sockets, 1);
417 st = find_vsock_stat(&sockets, client_fd);
418 check_socket_state(st, TCP_ESTABLISHED);
419
420 control_writeln("DONE");
421 control_expectln("DONE");
422
423 close(client_fd);
424 free_sock_stat(&sockets);
425 }
426
427 static struct test_case test_cases[] = {
428 {
429 .name = "No sockets",
430 .run_server = test_no_sockets,
431 },
432 {
433 .name = "Listen socket",
434 .run_server = test_listen_socket_server,
435 },
436 {
437 .name = "Connect",
438 .run_client = test_connect_client,
439 .run_server = test_connect_server,
440 },
441 {},
442 };
443
/* No short options; everything is a --long option mapped to a val code. */
static const char optstring[] = "";
static const struct option longopts[] = {
	{ .name = "control-host", .has_arg = required_argument, .val = 'H' },
	{ .name = "control-port", .has_arg = required_argument, .val = 'P' },
	{ .name = "mode",         .has_arg = required_argument, .val = 'm' },
	{ .name = "peer-cid",     .has_arg = required_argument, .val = 'p' },
	{ .name = "peer-port",    .has_arg = required_argument, .val = 'q' },
	{ .name = "list",         .has_arg = no_argument,       .val = 'l' },
	{ .name = "skip",         .has_arg = required_argument, .val = 's' },
	{ .name = "help",         .has_arg = no_argument,       .val = '?' },
	{},	/* terminator required by getopt_long() */
};
488
usage(void)489 static void usage(void)
490 {
491 fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--peer-port=<port>] [--list] [--skip=<test_id>]\n"
492 "\n"
493 " Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
494 " Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
495 "\n"
496 "Run vsock_diag.ko tests. Must be launched in both\n"
497 "guest and host. One side must use --mode=client and\n"
498 "the other side must use --mode=server.\n"
499 "\n"
500 "A TCP control socket connection is used to coordinate tests\n"
501 "between the client and the server. The server requires a\n"
502 "listen address and the client requires an address to\n"
503 "connect to.\n"
504 "\n"
505 "The CID of the other side must be given with --peer-cid=<cid>.\n"
506 "\n"
507 "Options:\n"
508 " --help This help message\n"
509 " --control-host <host> Server IP address to connect to\n"
510 " --control-port <port> Server port to listen on/connect to\n"
511 " --mode client|server Server or client mode\n"
512 " --peer-cid <cid> CID of the other side\n"
513 " --peer-port <port> AF_VSOCK port used for the test [default: %d]\n"
514 " --list List of tests that will be executed\n"
515 " --skip <test_id> Test ID to skip;\n"
516 " use multiple --skip options to skip more tests\n",
517 DEFAULT_PEER_PORT
518 );
519 exit(EXIT_FAILURE);
520 }
521
main(int argc,char ** argv)522 int main(int argc, char **argv)
523 {
524 const char *control_host = NULL;
525 const char *control_port = NULL;
526 struct test_opts opts = {
527 .mode = TEST_MODE_UNSET,
528 .peer_cid = VMADDR_CID_ANY,
529 .peer_port = DEFAULT_PEER_PORT,
530 };
531
532 init_signals();
533
534 for (;;) {
535 int opt = getopt_long(argc, argv, optstring, longopts, NULL);
536
537 if (opt == -1)
538 break;
539
540 switch (opt) {
541 case 'H':
542 control_host = optarg;
543 break;
544 case 'm':
545 if (strcmp(optarg, "client") == 0)
546 opts.mode = TEST_MODE_CLIENT;
547 else if (strcmp(optarg, "server") == 0)
548 opts.mode = TEST_MODE_SERVER;
549 else {
550 fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
551 return EXIT_FAILURE;
552 }
553 break;
554 case 'p':
555 opts.peer_cid = parse_cid(optarg);
556 break;
557 case 'q':
558 opts.peer_port = parse_port(optarg);
559 break;
560 case 'P':
561 control_port = optarg;
562 break;
563 case 'l':
564 list_tests(test_cases);
565 break;
566 case 's':
567 skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
568 optarg);
569 break;
570 case '?':
571 default:
572 usage();
573 }
574 }
575
576 if (!control_port)
577 usage();
578 if (opts.mode == TEST_MODE_UNSET)
579 usage();
580 if (opts.peer_cid == VMADDR_CID_ANY)
581 usage();
582
583 if (!control_host) {
584 if (opts.mode != TEST_MODE_SERVER)
585 usage();
586 control_host = "0.0.0.0";
587 }
588
589 control_init(control_host, control_port,
590 opts.mode == TEST_MODE_SERVER);
591
592 run_tests(test_cases, &opts);
593
594 control_cleanup();
595 return EXIT_SUCCESS;
596 }
597