xref: /linux/tools/testing/vsock/util.c (revision 6f19b2c136d98a84d79030b53e23d405edfdc783)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * vsock test utilities
4  *
5  * Copyright (C) 2017 Red Hat, Inc.
6  *
7  * Author: Stefan Hajnoczi <stefanha@redhat.com>
8  */
9 
#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/epoll.h>
#include <sys/mman.h>
20 
21 #include "timeout.h"
22 #include "control.h"
23 #include "util.h"
24 
25 /* Install signal handlers */
26 void init_signals(void)
27 {
28 	struct sigaction act = {
29 		.sa_handler = sigalrm,
30 	};
31 
32 	sigaction(SIGALRM, &act, NULL);
33 	signal(SIGPIPE, SIG_IGN);
34 }
35 
/* Parse a CID given as a decimal string.
 *
 * The whole string must be a base-10 number that fits in the unsigned int
 * returned to the caller.  An empty string, trailing characters, or an
 * out-of-range value is reported on stderr and terminates the process
 * with EXIT_FAILURE.
 */
unsigned int parse_cid(const char *str)
{
	char *endptr = NULL;
	unsigned long n;

	errno = 0;
	n = strtoul(str, &endptr, 10);
	/* endptr == str: no digits were consumed (e.g. ""), which strtoul()
	 * would otherwise silently turn into CID 0.  n > UINT_MAX: the value
	 * would be truncated by the unsigned int return type.
	 */
	if (errno || endptr == str || *endptr != '\0' || n > UINT_MAX) {
		fprintf(stderr, "malformed CID \"%s\"\n", str);
		exit(EXIT_FAILURE);
	}
	return (unsigned int)n;
}
50 
51 /* Wait for the remote to close the connection */
52 void vsock_wait_remote_close(int fd)
53 {
54 	struct epoll_event ev;
55 	int epollfd, nfds;
56 
57 	epollfd = epoll_create1(0);
58 	if (epollfd == -1) {
59 		perror("epoll_create1");
60 		exit(EXIT_FAILURE);
61 	}
62 
63 	ev.events = EPOLLRDHUP | EPOLLHUP;
64 	ev.data.fd = fd;
65 	if (epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev) == -1) {
66 		perror("epoll_ctl");
67 		exit(EXIT_FAILURE);
68 	}
69 
70 	nfds = epoll_wait(epollfd, &ev, 1, TIMEOUT * 1000);
71 	if (nfds == -1) {
72 		perror("epoll_wait");
73 		exit(EXIT_FAILURE);
74 	}
75 
76 	if (nfds == 0) {
77 		fprintf(stderr, "epoll_wait timed out\n");
78 		exit(EXIT_FAILURE);
79 	}
80 
81 	assert(nfds == 1);
82 	assert(ev.events & (EPOLLRDHUP | EPOLLHUP));
83 	assert(ev.data.fd == fd);
84 
85 	close(epollfd);
86 }
87 
88 /* Connect to <cid, port> and return the file descriptor. */
89 static int vsock_connect(unsigned int cid, unsigned int port, int type)
90 {
91 	union {
92 		struct sockaddr sa;
93 		struct sockaddr_vm svm;
94 	} addr = {
95 		.svm = {
96 			.svm_family = AF_VSOCK,
97 			.svm_port = port,
98 			.svm_cid = cid,
99 		},
100 	};
101 	int ret;
102 	int fd;
103 
104 	control_expectln("LISTENING");
105 
106 	fd = socket(AF_VSOCK, type, 0);
107 
108 	timeout_begin(TIMEOUT);
109 	do {
110 		ret = connect(fd, &addr.sa, sizeof(addr.svm));
111 		timeout_check("connect");
112 	} while (ret < 0 && errno == EINTR);
113 	timeout_end();
114 
115 	if (ret < 0) {
116 		int old_errno = errno;
117 
118 		close(fd);
119 		fd = -1;
120 		errno = old_errno;
121 	}
122 	return fd;
123 }
124 
/* Open a SOCK_STREAM connection to <cid, port>; see vsock_connect(). */
int vsock_stream_connect(unsigned int cid, unsigned int port)
{
	const int sock_type = SOCK_STREAM;

	return vsock_connect(cid, port, sock_type);
}
129 
/* Open a SOCK_SEQPACKET connection to <cid, port>; see vsock_connect(). */
int vsock_seqpacket_connect(unsigned int cid, unsigned int port)
{
	const int sock_type = SOCK_SEQPACKET;

	return vsock_connect(cid, port, sock_type);
}
134 
135 /* Listen on <cid, port> and return the first incoming connection.  The remote
136  * address is stored to clientaddrp.  clientaddrp may be NULL.
137  */
138 static int vsock_accept(unsigned int cid, unsigned int port,
139 			struct sockaddr_vm *clientaddrp, int type)
140 {
141 	union {
142 		struct sockaddr sa;
143 		struct sockaddr_vm svm;
144 	} addr = {
145 		.svm = {
146 			.svm_family = AF_VSOCK,
147 			.svm_port = port,
148 			.svm_cid = cid,
149 		},
150 	};
151 	union {
152 		struct sockaddr sa;
153 		struct sockaddr_vm svm;
154 	} clientaddr;
155 	socklen_t clientaddr_len = sizeof(clientaddr.svm);
156 	int fd;
157 	int client_fd;
158 	int old_errno;
159 
160 	fd = socket(AF_VSOCK, type, 0);
161 
162 	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
163 		perror("bind");
164 		exit(EXIT_FAILURE);
165 	}
166 
167 	if (listen(fd, 1) < 0) {
168 		perror("listen");
169 		exit(EXIT_FAILURE);
170 	}
171 
172 	control_writeln("LISTENING");
173 
174 	timeout_begin(TIMEOUT);
175 	do {
176 		client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
177 		timeout_check("accept");
178 	} while (client_fd < 0 && errno == EINTR);
179 	timeout_end();
180 
181 	old_errno = errno;
182 	close(fd);
183 	errno = old_errno;
184 
185 	if (client_fd < 0)
186 		return client_fd;
187 
188 	if (clientaddr_len != sizeof(clientaddr.svm)) {
189 		fprintf(stderr, "unexpected addrlen from accept(2), %zu\n",
190 			(size_t)clientaddr_len);
191 		exit(EXIT_FAILURE);
192 	}
193 	if (clientaddr.sa.sa_family != AF_VSOCK) {
194 		fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
195 			clientaddr.sa.sa_family);
196 		exit(EXIT_FAILURE);
197 	}
198 
199 	if (clientaddrp)
200 		*clientaddrp = clientaddr.svm;
201 	return client_fd;
202 }
203 
/* Accept one SOCK_STREAM connection on <cid, port>; see vsock_accept().
 * clientaddrp may be NULL.
 */
int vsock_stream_accept(unsigned int cid, unsigned int port,
			struct sockaddr_vm *clientaddrp)
{
	const int sock_type = SOCK_STREAM;

	return vsock_accept(cid, port, clientaddrp, sock_type);
}
209 
/* Accept one SOCK_SEQPACKET connection on <cid, port>; see vsock_accept().
 * clientaddrp may be NULL.
 */
int vsock_seqpacket_accept(unsigned int cid, unsigned int port,
			   struct sockaddr_vm *clientaddrp)
{
	const int sock_type = SOCK_SEQPACKET;

	return vsock_accept(cid, port, clientaddrp, sock_type);
}
215 
216 /* Transmit bytes from a buffer and check the return value.
217  *
218  * expected_ret:
219  *  <0 Negative errno (for testing errors)
220  *   0 End-of-file
221  *  >0 Success (bytes successfully written)
222  */
223 void send_buf(int fd, const void *buf, size_t len, int flags,
224 	      ssize_t expected_ret)
225 {
226 	ssize_t nwritten = 0;
227 	ssize_t ret;
228 
229 	timeout_begin(TIMEOUT);
230 	do {
231 		ret = send(fd, buf + nwritten, len - nwritten, flags);
232 		timeout_check("send");
233 
234 		if (ret == 0 || (ret < 0 && errno != EINTR))
235 			break;
236 
237 		nwritten += ret;
238 	} while (nwritten < len);
239 	timeout_end();
240 
241 	if (expected_ret < 0) {
242 		if (ret != -1) {
243 			fprintf(stderr, "bogus send(2) return value %zd (expected %zd)\n",
244 				ret, expected_ret);
245 			exit(EXIT_FAILURE);
246 		}
247 		if (errno != -expected_ret) {
248 			perror("send");
249 			exit(EXIT_FAILURE);
250 		}
251 		return;
252 	}
253 
254 	if (ret < 0) {
255 		perror("send");
256 		exit(EXIT_FAILURE);
257 	}
258 
259 	if (nwritten != expected_ret) {
260 		if (ret == 0)
261 			fprintf(stderr, "unexpected EOF while sending bytes\n");
262 
263 		fprintf(stderr, "bogus send(2) bytes written %zd (expected %zd)\n",
264 			nwritten, expected_ret);
265 		exit(EXIT_FAILURE);
266 	}
267 }
268 
269 /* Receive bytes in a buffer and check the return value.
270  *
271  * expected_ret:
272  *  <0 Negative errno (for testing errors)
273  *   0 End-of-file
274  *  >0 Success (bytes successfully read)
275  */
276 void recv_buf(int fd, void *buf, size_t len, int flags, ssize_t expected_ret)
277 {
278 	ssize_t nread = 0;
279 	ssize_t ret;
280 
281 	timeout_begin(TIMEOUT);
282 	do {
283 		ret = recv(fd, buf + nread, len - nread, flags);
284 		timeout_check("recv");
285 
286 		if (ret == 0 || (ret < 0 && errno != EINTR))
287 			break;
288 
289 		nread += ret;
290 	} while (nread < len);
291 	timeout_end();
292 
293 	if (expected_ret < 0) {
294 		if (ret != -1) {
295 			fprintf(stderr, "bogus recv(2) return value %zd (expected %zd)\n",
296 				ret, expected_ret);
297 			exit(EXIT_FAILURE);
298 		}
299 		if (errno != -expected_ret) {
300 			perror("recv");
301 			exit(EXIT_FAILURE);
302 		}
303 		return;
304 	}
305 
306 	if (ret < 0) {
307 		perror("recv");
308 		exit(EXIT_FAILURE);
309 	}
310 
311 	if (nread != expected_ret) {
312 		if (ret == 0)
313 			fprintf(stderr, "unexpected EOF while receiving bytes\n");
314 
315 		fprintf(stderr, "bogus recv(2) bytes read %zd (expected %zd)\n",
316 			nread, expected_ret);
317 		exit(EXIT_FAILURE);
318 	}
319 }
320 
/* Send the single marker byte 'A' on fd and verify the outcome through
 * send_buf().
 *
 * expected_ret:
 *  <0 Negative errno (for testing errors)
 *   0 End-of-file
 *   1 Success
 */
void send_byte(int fd, int expected_ret, int flags)
{
	const uint8_t payload[1] = { 'A' };

	send_buf(fd, payload, sizeof(payload), flags, expected_ret);
}
334 
/* Receive one byte and check the return value.
 *
 * expected_ret:
 *  <0 Negative errno (for testing errors)
 *   0 End-of-file
 *   1 Success
 *
 * On success the byte must be the marker 'A' sent by send_byte();
 * anything else terminates the process with EXIT_FAILURE.
 */
void recv_byte(int fd, int expected_ret, int flags)
{
	uint8_t byte;

	recv_buf(fd, &byte, sizeof(byte), flags, expected_ret);

	/* recv_buf() only returns for expected_ret <= 0 after the expected
	 * error or an EOF, i.e. without storing anything into 'byte'.
	 * Checking its value then would read an uninitialized variable.
	 */
	if (expected_ret <= 0)
		return;

	if (byte != 'A') {
		fprintf(stderr, "unexpected byte read %c\n", byte);
		exit(EXIT_FAILURE);
	}
}
353 
354 /* Run test cases.  The program terminates if a failure occurs. */
355 void run_tests(const struct test_case *test_cases,
356 	       const struct test_opts *opts)
357 {
358 	int i;
359 
360 	for (i = 0; test_cases[i].name; i++) {
361 		void (*run)(const struct test_opts *opts);
362 		char *line;
363 
364 		printf("%d - %s...", i, test_cases[i].name);
365 		fflush(stdout);
366 
367 		/* Full barrier before executing the next test.  This
368 		 * ensures that client and server are executing the
369 		 * same test case.  In particular, it means whoever is
370 		 * faster will not see the peer still executing the
371 		 * last test.  This is important because port numbers
372 		 * can be used by multiple test cases.
373 		 */
374 		if (test_cases[i].skip)
375 			control_writeln("SKIP");
376 		else
377 			control_writeln("NEXT");
378 
379 		line = control_readln();
380 		if (control_cmpln(line, "SKIP", false) || test_cases[i].skip) {
381 
382 			printf("skipped\n");
383 
384 			free(line);
385 			continue;
386 		}
387 
388 		control_cmpln(line, "NEXT", true);
389 		free(line);
390 
391 		if (opts->mode == TEST_MODE_CLIENT)
392 			run = test_cases[i].run_client;
393 		else
394 			run = test_cases[i].run_server;
395 
396 		if (run)
397 			run(opts);
398 
399 		printf("ok\n");
400 	}
401 }
402 
403 void list_tests(const struct test_case *test_cases)
404 {
405 	int i;
406 
407 	printf("ID\tTest name\n");
408 
409 	for (i = 0; test_cases[i].name; i++)
410 		printf("%d\t%s\n", i, test_cases[i].name);
411 
412 	exit(EXIT_FAILURE);
413 }
414 
415 void skip_test(struct test_case *test_cases, size_t test_cases_len,
416 	       const char *test_id_str)
417 {
418 	unsigned long test_id;
419 	char *endptr = NULL;
420 
421 	errno = 0;
422 	test_id = strtoul(test_id_str, &endptr, 10);
423 	if (errno || *endptr != '\0') {
424 		fprintf(stderr, "malformed test ID \"%s\"\n", test_id_str);
425 		exit(EXIT_FAILURE);
426 	}
427 
428 	if (test_id >= test_cases_len) {
429 		fprintf(stderr, "test ID (%lu) larger than the max allowed (%lu)\n",
430 			test_id, test_cases_len - 1);
431 		exit(EXIT_FAILURE);
432 	}
433 
434 	test_cases[test_id].skip = true;
435 }
436 
/* Compute the djb2 hash of the len bytes at data.
 *
 * Starting from 5381, each byte (as unsigned char) is folded in as
 * hash = hash * 33 + byte.  With len == 0, data is not read and the seed
 * 5381 is returned.
 */
unsigned long hash_djb2(const void *data, size_t len)
{
	const unsigned char *bytes = data;
	unsigned long hash = 5381;
	size_t i;

	/* size_t index: an int would overflow (UB) for len > INT_MAX. */
	for (i = 0; i < len; i++)
		hash = ((hash << 5) + hash) + bytes[i];

	return hash;
}
449 
/* Return the total number of bytes described by the first iovnum entries
 * of iov, i.e. the sum of their iov_len fields (0 when iovnum is 0).
 */
size_t iovec_bytes(const struct iovec *iov, size_t iovnum)
{
	size_t bytes = 0;
	size_t i;

	/* size_t index: matches iovnum's type, so there is no signed/unsigned
	 * comparison and no int overflow for very large counts.
	 */
	for (i = 0; i < iovnum; i++)
		bytes += iov[i].iov_len;

	return bytes;
}
460 
/* Hash the contents of the iovnum buffers in iov, laid end to end, with
 * hash_djb2().
 *
 * The buffers are copied into one temporary heap buffer first; failure to
 * allocate it terminates the process with EXIT_FAILURE.
 */
unsigned long iovec_hash_djb2(const struct iovec *iov, size_t iovnum)
{
	unsigned long hash;
	size_t iov_bytes;
	size_t offs;
	char *tmp;
	size_t i;

	iov_bytes = iovec_bytes(iov, iovnum);

	/* Nothing to hash: skip malloc(0), which may legitimately return
	 * NULL and would then be misreported as an allocation failure.
	 */
	if (!iov_bytes)
		return hash_djb2(NULL, 0);

	tmp = malloc(iov_bytes);
	if (!tmp) {
		perror("malloc");
		exit(EXIT_FAILURE);
	}

	for (offs = 0, i = 0; i < iovnum; i++) {
		/* Empty entries contribute nothing; their base may not be a
		 * valid pointer, so don't hand it to memcpy().
		 */
		if (!iov[i].iov_len)
			continue;

		memcpy(tmp + offs, iov[i].iov_base, iov[i].iov_len);
		offs += iov[i].iov_len;
	}

	hash = hash_djb2(tmp, iov_bytes);
	free(tmp);

	return hash;
}
487 
/* Allocates and returns a new 'struct iovec *' according to the pattern
 * in 'test_iovec'. For each element in 'test_iovec' it allocates a
 * new element in the resulting 'iovec'. 'iov_len' of the new element
 * is copied from 'test_iovec'. 'iov_base' is allocated depending on
 * the 'iov_base' of 'test_iovec':
 *
 * 'iov_base' == NULL -> valid buf: mmap('iov_len').
 *
 * 'iov_base' == MAP_FAILED -> invalid buf:
 *               mmap('iov_len'), then munmap('iov_len').
 *               'iov_base' still contains result of
 *               mmap().
 *
 * 'iov_base' == number -> unaligned valid buf:
 *               mmap('iov_len' + number) + number, so that all
 *               'iov_len' bytes after the unaligned start are
 *               backed by the mapping.
 *
 * 'iovnum' is number of elements in 'test_iovec'.
 *
 * Every valid buffer is filled with rand() bytes.
 *
 * Returns new 'iovec' or calls 'exit()' on error.
 */
struct iovec *alloc_test_iovec(const struct iovec *test_iovec, int iovnum)
{
	struct iovec *iovec;
	int i;

	iovec = malloc(sizeof(*iovec) * iovnum);
	if (!iovec) {
		perror("malloc");
		exit(EXIT_FAILURE);
	}

	for (i = 0; i < iovnum; i++) {
		size_t offset = 0;
		void *map;

		/* A non-MAP_FAILED pattern base is a byte offset (0 for NULL). */
		if (test_iovec[i].iov_base != MAP_FAILED)
			offset = (uintptr_t)test_iovec[i].iov_base;

		iovec[i].iov_len = test_iovec[i].iov_len;

		/* Map 'offset' extra bytes: the buffer handed out starts
		 * 'offset' bytes into the mapping and must still provide
		 * 'iov_len' accessible bytes.  Mapping only 'iov_len' would
		 * overrun the mapping whenever offset + iov_len crosses its
		 * last page (e.g. iov_len == page size, offset == 1).
		 */
		map = mmap(NULL, iovec[i].iov_len + offset,
			   PROT_READ | PROT_WRITE,
			   MAP_PRIVATE | MAP_ANONYMOUS | MAP_POPULATE,
			   -1, 0);
		if (map == MAP_FAILED) {
			perror("mmap");
			exit(EXIT_FAILURE);
		}

		iovec[i].iov_base = (char *)map + offset;
	}
	/* NOTE(review): free_test_iovec() unmaps only 'iov_len' bytes from the
	 * mapping start; when 'iov_len' + offset needs one more page than
	 * 'iov_len', that tail page is left mapped until process exit.
	 */

	/* Unmap "invalid" elements. */
	for (i = 0; i < iovnum; i++) {
		if (test_iovec[i].iov_base == MAP_FAILED) {
			if (munmap(iovec[i].iov_base, iovec[i].iov_len)) {
				perror("munmap");
				exit(EXIT_FAILURE);
			}
		}
	}

	/* Fill every valid buffer with pseudo-random bytes. */
	for (i = 0; i < iovnum; i++) {
		uint8_t *bytes;
		size_t j;

		if (test_iovec[i].iov_base == MAP_FAILED)
			continue;

		bytes = iovec[i].iov_base;
		for (j = 0; j < iovec[i].iov_len; j++)
			bytes[j] = rand() & 0xff;
	}

	return iovec;
}
557 
/* Release an iovec array obtained from alloc_test_iovec() for the same
 * test_iovec pattern.
 *
 * Entries whose pattern base is MAP_FAILED were already unmapped when the
 * array was built and are skipped.  For every other entry the pattern base
 * (0 for NULL, otherwise the offset that was added) is subtracted to
 * recover the mapping start, and 'iov_len' bytes are unmapped from there.
 * The array itself is freed last.  A munmap() failure terminates the
 * process with EXIT_FAILURE.
 */
void free_test_iovec(const struct iovec *test_iovec,
		     struct iovec *iovec, int iovnum)
{
	int idx;

	for (idx = 0; idx < iovnum; idx++) {
		const struct iovec *pattern = &test_iovec[idx];
		char *mapping;

		if (pattern->iov_base == MAP_FAILED)
			continue;

		mapping = (char *)iovec[idx].iov_base -
			  (uintptr_t)pattern->iov_base;

		if (munmap(mapping, iovec[idx].iov_len)) {
			perror("munmap");
			exit(EXIT_FAILURE);
		}
	}

	free(iovec);
}
580