xref: /linux/tools/testing/vsock/util.c (revision df561f6688fef775baa341a0f5d960becd248b11)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * vsock test utilities
4  *
5  * Copyright (C) 2017 Red Hat, Inc.
6  *
7  * Author: Stefan Hajnoczi <stefanha@redhat.com>
8  */
9 
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <signal.h>
#include <unistd.h>
#include <assert.h>
#include <sys/epoll.h>

#include "timeout.h"
#include "control.h"
#include "util.h"
22 
23 /* Install signal handlers */
24 void init_signals(void)
25 {
26 	struct sigaction act = {
27 		.sa_handler = sigalrm,
28 	};
29 
30 	sigaction(SIGALRM, &act, NULL);
31 	signal(SIGPIPE, SIG_IGN);
32 }
33 
/* Parse a CID given as a decimal string.
 *
 * @str: decimal representation of the CID, e.g. "2".
 *
 * Returns the CID.  Prints a diagnostic and exits with EXIT_FAILURE when
 * @str contains no digits, has trailing characters, overflows, or does not
 * fit in an unsigned int.
 */
unsigned int parse_cid(const char *str)
{
	char *endptr = NULL;
	unsigned long n;

	errno = 0;
	n = strtoul(str, &endptr, 10);
	/* endptr == str: no digits were consumed (e.g. ""), which strtoul
	 * reports as 0 without setting errno.
	 * n > UINT_MAX: on LP64 the value would otherwise be silently
	 * truncated by the conversion to the unsigned int return type (this
	 * also rejects negative input such as "-1", which strtoul negates
	 * into a huge unsigned long).
	 */
	if (errno || endptr == str || *endptr != '\0' || n > UINT_MAX) {
		fprintf(stderr, "malformed CID \"%s\"\n", str);
		exit(EXIT_FAILURE);
	}
	return (unsigned int)n;
}
48 
49 /* Wait for the remote to close the connection */
50 void vsock_wait_remote_close(int fd)
51 {
52 	struct epoll_event ev;
53 	int epollfd, nfds;
54 
55 	epollfd = epoll_create1(0);
56 	if (epollfd == -1) {
57 		perror("epoll_create1");
58 		exit(EXIT_FAILURE);
59 	}
60 
61 	ev.events = EPOLLRDHUP | EPOLLHUP;
62 	ev.data.fd = fd;
63 	if (epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev) == -1) {
64 		perror("epoll_ctl");
65 		exit(EXIT_FAILURE);
66 	}
67 
68 	nfds = epoll_wait(epollfd, &ev, 1, TIMEOUT * 1000);
69 	if (nfds == -1) {
70 		perror("epoll_wait");
71 		exit(EXIT_FAILURE);
72 	}
73 
74 	if (nfds == 0) {
75 		fprintf(stderr, "epoll_wait timed out\n");
76 		exit(EXIT_FAILURE);
77 	}
78 
79 	assert(nfds == 1);
80 	assert(ev.events & (EPOLLRDHUP | EPOLLHUP));
81 	assert(ev.data.fd == fd);
82 
83 	close(epollfd);
84 }
85 
86 /* Connect to <cid, port> and return the file descriptor. */
87 int vsock_stream_connect(unsigned int cid, unsigned int port)
88 {
89 	union {
90 		struct sockaddr sa;
91 		struct sockaddr_vm svm;
92 	} addr = {
93 		.svm = {
94 			.svm_family = AF_VSOCK,
95 			.svm_port = port,
96 			.svm_cid = cid,
97 		},
98 	};
99 	int ret;
100 	int fd;
101 
102 	control_expectln("LISTENING");
103 
104 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
105 
106 	timeout_begin(TIMEOUT);
107 	do {
108 		ret = connect(fd, &addr.sa, sizeof(addr.svm));
109 		timeout_check("connect");
110 	} while (ret < 0 && errno == EINTR);
111 	timeout_end();
112 
113 	if (ret < 0) {
114 		int old_errno = errno;
115 
116 		close(fd);
117 		fd = -1;
118 		errno = old_errno;
119 	}
120 	return fd;
121 }
122 
123 /* Listen on <cid, port> and return the first incoming connection.  The remote
124  * address is stored to clientaddrp.  clientaddrp may be NULL.
125  */
126 int vsock_stream_accept(unsigned int cid, unsigned int port,
127 			struct sockaddr_vm *clientaddrp)
128 {
129 	union {
130 		struct sockaddr sa;
131 		struct sockaddr_vm svm;
132 	} addr = {
133 		.svm = {
134 			.svm_family = AF_VSOCK,
135 			.svm_port = port,
136 			.svm_cid = cid,
137 		},
138 	};
139 	union {
140 		struct sockaddr sa;
141 		struct sockaddr_vm svm;
142 	} clientaddr;
143 	socklen_t clientaddr_len = sizeof(clientaddr.svm);
144 	int fd;
145 	int client_fd;
146 	int old_errno;
147 
148 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
149 
150 	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
151 		perror("bind");
152 		exit(EXIT_FAILURE);
153 	}
154 
155 	if (listen(fd, 1) < 0) {
156 		perror("listen");
157 		exit(EXIT_FAILURE);
158 	}
159 
160 	control_writeln("LISTENING");
161 
162 	timeout_begin(TIMEOUT);
163 	do {
164 		client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
165 		timeout_check("accept");
166 	} while (client_fd < 0 && errno == EINTR);
167 	timeout_end();
168 
169 	old_errno = errno;
170 	close(fd);
171 	errno = old_errno;
172 
173 	if (client_fd < 0)
174 		return client_fd;
175 
176 	if (clientaddr_len != sizeof(clientaddr.svm)) {
177 		fprintf(stderr, "unexpected addrlen from accept(2), %zu\n",
178 			(size_t)clientaddr_len);
179 		exit(EXIT_FAILURE);
180 	}
181 	if (clientaddr.sa.sa_family != AF_VSOCK) {
182 		fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
183 			clientaddr.sa.sa_family);
184 		exit(EXIT_FAILURE);
185 	}
186 
187 	if (clientaddrp)
188 		*clientaddrp = clientaddr.svm;
189 	return client_fd;
190 }
191 
192 /* Transmit one byte and check the return value.
193  *
194  * expected_ret:
195  *  <0 Negative errno (for testing errors)
196  *   0 End-of-file
197  *   1 Success
198  */
199 void send_byte(int fd, int expected_ret, int flags)
200 {
201 	const uint8_t byte = 'A';
202 	ssize_t nwritten;
203 
204 	timeout_begin(TIMEOUT);
205 	do {
206 		nwritten = send(fd, &byte, sizeof(byte), flags);
207 		timeout_check("write");
208 	} while (nwritten < 0 && errno == EINTR);
209 	timeout_end();
210 
211 	if (expected_ret < 0) {
212 		if (nwritten != -1) {
213 			fprintf(stderr, "bogus send(2) return value %zd\n",
214 				nwritten);
215 			exit(EXIT_FAILURE);
216 		}
217 		if (errno != -expected_ret) {
218 			perror("write");
219 			exit(EXIT_FAILURE);
220 		}
221 		return;
222 	}
223 
224 	if (nwritten < 0) {
225 		perror("write");
226 		exit(EXIT_FAILURE);
227 	}
228 	if (nwritten == 0) {
229 		if (expected_ret == 0)
230 			return;
231 
232 		fprintf(stderr, "unexpected EOF while sending byte\n");
233 		exit(EXIT_FAILURE);
234 	}
235 	if (nwritten != sizeof(byte)) {
236 		fprintf(stderr, "bogus send(2) return value %zd\n", nwritten);
237 		exit(EXIT_FAILURE);
238 	}
239 }
240 
241 /* Receive one byte and check the return value.
242  *
243  * expected_ret:
244  *  <0 Negative errno (for testing errors)
245  *   0 End-of-file
246  *   1 Success
247  */
248 void recv_byte(int fd, int expected_ret, int flags)
249 {
250 	uint8_t byte;
251 	ssize_t nread;
252 
253 	timeout_begin(TIMEOUT);
254 	do {
255 		nread = recv(fd, &byte, sizeof(byte), flags);
256 		timeout_check("read");
257 	} while (nread < 0 && errno == EINTR);
258 	timeout_end();
259 
260 	if (expected_ret < 0) {
261 		if (nread != -1) {
262 			fprintf(stderr, "bogus recv(2) return value %zd\n",
263 				nread);
264 			exit(EXIT_FAILURE);
265 		}
266 		if (errno != -expected_ret) {
267 			perror("read");
268 			exit(EXIT_FAILURE);
269 		}
270 		return;
271 	}
272 
273 	if (nread < 0) {
274 		perror("read");
275 		exit(EXIT_FAILURE);
276 	}
277 	if (nread == 0) {
278 		if (expected_ret == 0)
279 			return;
280 
281 		fprintf(stderr, "unexpected EOF while receiving byte\n");
282 		exit(EXIT_FAILURE);
283 	}
284 	if (nread != sizeof(byte)) {
285 		fprintf(stderr, "bogus recv(2) return value %zd\n", nread);
286 		exit(EXIT_FAILURE);
287 	}
288 	if (byte != 'A') {
289 		fprintf(stderr, "unexpected byte read %c\n", byte);
290 		exit(EXIT_FAILURE);
291 	}
292 }
293 
294 /* Run test cases.  The program terminates if a failure occurs. */
295 void run_tests(const struct test_case *test_cases,
296 	       const struct test_opts *opts)
297 {
298 	int i;
299 
300 	for (i = 0; test_cases[i].name; i++) {
301 		void (*run)(const struct test_opts *opts);
302 		char *line;
303 
304 		printf("%d - %s...", i, test_cases[i].name);
305 		fflush(stdout);
306 
307 		/* Full barrier before executing the next test.  This
308 		 * ensures that client and server are executing the
309 		 * same test case.  In particular, it means whoever is
310 		 * faster will not see the peer still executing the
311 		 * last test.  This is important because port numbers
312 		 * can be used by multiple test cases.
313 		 */
314 		if (test_cases[i].skip)
315 			control_writeln("SKIP");
316 		else
317 			control_writeln("NEXT");
318 
319 		line = control_readln();
320 		if (control_cmpln(line, "SKIP", false) || test_cases[i].skip) {
321 
322 			printf("skipped\n");
323 
324 			free(line);
325 			continue;
326 		}
327 
328 		control_cmpln(line, "NEXT", true);
329 		free(line);
330 
331 		if (opts->mode == TEST_MODE_CLIENT)
332 			run = test_cases[i].run_client;
333 		else
334 			run = test_cases[i].run_server;
335 
336 		if (run)
337 			run(opts);
338 
339 		printf("ok\n");
340 	}
341 }
342 
343 void list_tests(const struct test_case *test_cases)
344 {
345 	int i;
346 
347 	printf("ID\tTest name\n");
348 
349 	for (i = 0; test_cases[i].name; i++)
350 		printf("%d\t%s\n", i, test_cases[i].name);
351 
352 	exit(EXIT_FAILURE);
353 }
354 
355 void skip_test(struct test_case *test_cases, size_t test_cases_len,
356 	       const char *test_id_str)
357 {
358 	unsigned long test_id;
359 	char *endptr = NULL;
360 
361 	errno = 0;
362 	test_id = strtoul(test_id_str, &endptr, 10);
363 	if (errno || *endptr != '\0') {
364 		fprintf(stderr, "malformed test ID \"%s\"\n", test_id_str);
365 		exit(EXIT_FAILURE);
366 	}
367 
368 	if (test_id >= test_cases_len) {
369 		fprintf(stderr, "test ID (%lu) larger than the max allowed (%lu)\n",
370 			test_id, test_cases_len - 1);
371 		exit(EXIT_FAILURE);
372 	}
373 
374 	test_cases[test_id].skip = true;
375 }
376