xref: /freebsd/lib/libnv/msgio.c (revision 6683132d54bd6d589889e43dabdc53d35e38a028)
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3  *
4  * Copyright (c) 2013 The FreeBSD Foundation
5  * Copyright (c) 2013 Mariusz Zaborski <oshogbo@FreeBSD.org>
6  * All rights reserved.
7  *
8  * This software was developed by Pawel Jakub Dawidek under sponsorship from
9  * the FreeBSD Foundation.
10  *
11  * Redistribution and use in source and binary forms, with or without
12  * modification, are permitted provided that the following conditions
13  * are met:
14  * 1. Redistributions of source code must retain the above copyright
15  *    notice, this list of conditions and the following disclaimer.
16  * 2. Redistributions in binary form must reproduce the above copyright
17  *    notice, this list of conditions and the following disclaimer in the
18  *    documentation and/or other materials provided with the distribution.
19  *
20  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
21  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
24  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
26  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
27  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
28  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
29  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
30  * SUCH DAMAGE.
31  */
32 
33 #include <sys/cdefs.h>
34 __FBSDID("$FreeBSD$");
35 
#include <sys/param.h>
#include <sys/socket.h>
#include <sys/select.h>

#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
47 
48 #ifdef HAVE_PJDLOG
49 #include <pjdlog.h>
50 #endif
51 
52 #include "common_impl.h"
53 #include "msgio.h"
54 
#ifndef	HAVE_PJDLOG
/* Without pjdlog, fall back to plain assert(3)/abort(3). */
#include <assert.h>
#define	PJDLOG_ASSERT(...)		assert(__VA_ARGS__)
#define	PJDLOG_RASSERT(expr, ...)	assert(expr)
#define	PJDLOG_ABORT(...)		abort()
#endif

/*
 * PKG_MAX_SIZE is the maximum number of descriptors transferred in one
 * message by fd_send()/fd_recv().  Larger requests are split into packages
 * of at most this many descriptors, so sender and receiver must agree on it.
 * It is a descriptor COUNT, not a byte size.
 */
#ifdef __linux__
/*
 * Linux: arbitrary count, but it must be lower than SCM_MAX_FD (253), the
 * kernel's limit on the total number of descriptors passed in a single
 * sendmsg(2) call; exceeding it makes sendmsg(2) fail with EINVAL.
 */
#define	PKG_MAX_SIZE	(64U - 1)
#else
/*
 * One less than the number of single-descriptor control messages
 * (CMSG_SPACE(sizeof(int)) bytes each) that fit in an mbuf cluster.
 */
#define	PKG_MAX_SIZE	(MCLBYTES / CMSG_SPACE(sizeof(int)) - 1)
#endif
68 
/*
 * Turn the control-message header at cmsg into an SCM_RIGHTS message that
 * carries exactly one descriptor, fd.  The caller provides the storage
 * (fd_package_send() reserves CMSG_SPACE(sizeof(int)) bytes per descriptor).
 *
 * Cannot fail; always returns 0.
 */
static int
msghdr_add_fd(struct cmsghdr *cmsg, int fd)
{

	PJDLOG_ASSERT(fd >= 0);

	cmsg->cmsg_len = CMSG_LEN(sizeof(fd));
	cmsg->cmsg_level = SOL_SOCKET;
	cmsg->cmsg_type = SCM_RIGHTS;
	/* CMSG_DATA() may be unaligned for int; copy bytewise. */
	memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));

	return (0);
}
82 
/*
 * Block until fd is readable (doread == true) or writable (doread == false),
 * or until poll(2) reports an error/hang-up condition on it.
 *
 * poll(2) is used instead of select(2): FD_SET() on a descriptor number
 * >= FD_SETSIZE writes past the end of the fd_set, corrupting the stack in
 * processes that have many descriptors open.
 *
 * The result of poll(2) is deliberately ignored; every caller issues the
 * actual I/O call right afterwards and handles its errors (including EINTR).
 */
static void
fd_wait(int fd, bool doread)
{
	struct pollfd pfd;

	PJDLOG_ASSERT(fd >= 0);

	pfd.fd = fd;
	pfd.events = doread ? POLLIN : POLLOUT;
	pfd.revents = 0;
	/* A negative timeout waits indefinitely, like select() with NULL. */
	(void)poll(&pfd, 1, -1);
}
95 
/*
 * Receive one message from sock into the caller-prepared msg, waiting for
 * readability first and retrying when interrupted by a signal (EINTR).
 *
 * MSG_CMSG_CLOEXEC is requested when the platform defines it, so that any
 * received descriptors are marked close-on-exec atomically.
 *
 * Returns 0 once recvmsg(2) succeeds (a zero-byte read also counts as
 * success here), or -1 with errno set by recvmsg(2).
 */
static int
msg_recv(int sock, struct msghdr *msg)
{
	int rflags = 0;

	PJDLOG_ASSERT(sock >= 0);

#ifdef MSG_CMSG_CLOEXEC
	rflags |= MSG_CMSG_CLOEXEC;
#endif

	do {
		fd_wait(sock, true);
		if (recvmsg(sock, msg, rflags) != -1)
			return (0);
	} while (errno == EINTR);

	return (-1);
}
121 
/*
 * Send the caller-prepared msg over sock, waiting for writability first
 * and retrying when interrupted by a signal (EINTR).
 *
 * Returns 0 once sendmsg(2) succeeds, or -1 with errno set by sendmsg(2).
 */
static int
msg_send(int sock, const struct msghdr *msg)
{

	PJDLOG_ASSERT(sock >= 0);

	do {
		fd_wait(sock, false);
		if (sendmsg(sock, msg, 0) != -1)
			return (0);
	} while (errno == EINTR);

	return (-1);
}
140 
141 #ifdef __FreeBSD__
142 int
143 cred_send(int sock)
144 {
145 	unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
146 	struct msghdr msg;
147 	struct cmsghdr *cmsg;
148 	struct iovec iov;
149 	uint8_t dummy;
150 
151 	bzero(credbuf, sizeof(credbuf));
152 	bzero(&msg, sizeof(msg));
153 	bzero(&iov, sizeof(iov));
154 
155 	/*
156 	 * XXX: We send one byte along with the control message, because
157 	 *      setting msg_iov to NULL only works if this is the first
158 	 *      packet send over the socket. Once we send some data we
159 	 *      won't be able to send credentials anymore. This is most
160 	 *      likely a kernel bug.
161 	 */
162 	dummy = 0;
163 	iov.iov_base = &dummy;
164 	iov.iov_len = sizeof(dummy);
165 
166 	msg.msg_iov = &iov;
167 	msg.msg_iovlen = 1;
168 	msg.msg_control = credbuf;
169 	msg.msg_controllen = sizeof(credbuf);
170 
171 	cmsg = CMSG_FIRSTHDR(&msg);
172 	cmsg->cmsg_len = CMSG_LEN(sizeof(struct cmsgcred));
173 	cmsg->cmsg_level = SOL_SOCKET;
174 	cmsg->cmsg_type = SCM_CREDS;
175 
176 	if (msg_send(sock, &msg) == -1)
177 		return (-1);
178 
179 	return (0);
180 }
181 
182 int
183 cred_recv(int sock, struct cmsgcred *cred)
184 {
185 	unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
186 	struct msghdr msg;
187 	struct cmsghdr *cmsg;
188 	struct iovec iov;
189 	uint8_t dummy;
190 
191 	bzero(credbuf, sizeof(credbuf));
192 	bzero(&msg, sizeof(msg));
193 	bzero(&iov, sizeof(iov));
194 
195 	iov.iov_base = &dummy;
196 	iov.iov_len = sizeof(dummy);
197 
198 	msg.msg_iov = &iov;
199 	msg.msg_iovlen = 1;
200 	msg.msg_control = credbuf;
201 	msg.msg_controllen = sizeof(credbuf);
202 
203 	if (msg_recv(sock, &msg) == -1)
204 		return (-1);
205 
206 	cmsg = CMSG_FIRSTHDR(&msg);
207 	if (cmsg == NULL ||
208 	    cmsg->cmsg_len != CMSG_LEN(sizeof(struct cmsgcred)) ||
209 	    cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDS) {
210 		errno = EINVAL;
211 		return (-1);
212 	}
213 	bcopy(CMSG_DATA(cmsg), cred, sizeof(*cred));
214 
215 	return (0);
216 }
217 #endif
218 
/*
 * Send nfds descriptors from fds over sock in a single message.  Each
 * descriptor gets its own SCM_RIGHTS control message, and one dummy data
 * byte is attached (see the XXX note in cred_send()).
 *
 * Returns 0 on success, or -1 with errno set on failure.  errno from the
 * failing call survives the release of the control buffer.
 */
static int
fd_package_send(int sock, const int *fds, size_t nfds)
{
	struct msghdr mh;
	struct cmsghdr *ch;
	struct iovec vec;
	size_t idx;
	int saved_errno, rv;
	uint8_t pad;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(fds != NULL);
	PJDLOG_ASSERT(nfds > 0);

	memset(&mh, 0, sizeof(mh));

	pad = 0;
	vec.iov_base = &pad;
	vec.iov_len = sizeof(pad);

	mh.msg_iov = &vec;
	mh.msg_iovlen = 1;
	/* One CMSG_SPACE(sizeof(int)) slot per descriptor, zero-filled. */
	mh.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	mh.msg_control = calloc(1, mh.msg_controllen);
	if (mh.msg_control == NULL)
		return (-1);

	rv = -1;

	idx = 0;
	for (ch = CMSG_FIRSTHDR(&mh); ch != NULL && idx < nfds;
	    ch = CMSG_NXTHDR(&mh, ch)) {
		if (msghdr_add_fd(ch, fds[idx++]) == -1)
			goto out;
	}

	if (msg_send(sock, &mh) == 0)
		rv = 0;
out:
	saved_errno = errno;
	free(mh.msg_control);
	errno = saved_errno;
	return (rv);
}
267 
/*
 * Close every descriptor carried by an SCM_RIGHTS control message in msg.
 *
 * Used on the failure path of fd_package_recv().  By the time recvmsg(2)
 * returns, the descriptors are already installed in this process.  All of
 * them must be closed, including ones in control messages that were never
 * copied into the caller's array.
 */
static void
fd_package_close_received(struct msghdr *msg)
{
	struct cmsghdr *cmsg;
	size_t k, cnt;
	int fd;

	for (cmsg = CMSG_FIRSTHDR(msg); cmsg != NULL;
	    cmsg = CMSG_NXTHDR(msg, cmsg)) {
		if (cmsg->cmsg_level != SOL_SOCKET ||
		    cmsg->cmsg_type != SCM_RIGHTS ||
		    cmsg->cmsg_len < CMSG_LEN(0))
			continue;
		cnt = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
		for (k = 0; k < cnt; k++) {
			memcpy(&fd, CMSG_DATA(cmsg) + k * sizeof(int),
			    sizeof(fd));
			if (fd >= 0)
				close(fd);
		}
	}
}

/*
 * Receive exactly nfds descriptors from sock, sent as one message by
 * fd_package_send(), and store them in fds[0..nfds-1].
 *
 * The control data must consist only of SOL_SOCKET/SCM_RIGHTS messages
 * carrying exactly nfds descriptors in total.  Anything else (a different
 * control message such as SCM_CREDS, too few descriptors, or too many) is
 * treated as a protocol error.
 *
 * Returns 0 on success.  Returns -1 with errno set on failure: errno comes
 * from calloc/recvmsg, or is EINVAL for a protocol error.  On a protocol
 * error every descriptor received with the message has been closed.
 */
static int
fd_package_recv(int sock, int *fds, size_t nfds)
{
	struct msghdr msg;
	struct cmsghdr *cmsg;
	struct iovec iov;
	size_t i, n;
	int serrno, ret;
	uint8_t dummy;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(nfds > 0);
	PJDLOG_ASSERT(fds != NULL);

	memset(&msg, 0, sizeof(msg));
	memset(&iov, 0, sizeof(iov));

	/*
	 * XXX: Look into cred_send function for more details.
	 */
	iov.iov_base = &dummy;
	iov.iov_len = sizeof(dummy);

	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;
	msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	msg.msg_control = calloc(1, msg.msg_controllen);
	if (msg.msg_control == NULL)
		return (-1);

	ret = -1;

	if (msg_recv(sock, &msg) == -1)
		goto end;

	i = 0;
	cmsg = CMSG_FIRSTHDR(&msg);
	while (cmsg != NULL && i < nfds) {
		if (cmsg->cmsg_level != SOL_SOCKET ||
		    cmsg->cmsg_type != SCM_RIGHTS ||
		    cmsg->cmsg_len < CMSG_LEN(0))
			break;
		n = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
		/* Compare against the remaining room; cannot overflow. */
		if (n > nfds - i)
			break;
		memcpy(fds + i, CMSG_DATA(cmsg), sizeof(int) * n);
		cmsg = CMSG_NXTHDR(&msg, cmsg);
		i += n;
	}

	if (cmsg != NULL || i < nfds) {
		/*
		 * We need to close all received descriptors, even those in a
		 * control message we stopped at, those after a different
		 * control message (eg. SCM_CREDS) in between, and any surplus
		 * ones.  Walk the whole control buffer, not just fds[].
		 */
		fd_package_close_received(&msg);
		errno = EINVAL;
		goto end;
	}

#ifndef MSG_CMSG_CLOEXEC
	/*
	 * If the MSG_CMSG_CLOEXEC flag is not available we cannot set the
	 * close-on-exec flag atomically, but we still want to set it for
	 * consistency.
	 */
	for (i = 0; i < nfds; i++) {
		(void) fcntl(fds[i], F_SETFD, FD_CLOEXEC);
	}
#endif

	ret = 0;
end:
	serrno = errno;
	free(msg.msg_control);
	errno = serrno;
	return (ret);
}
358 
359 int
360 fd_recv(int sock, int *fds, size_t nfds)
361 {
362 	unsigned int i, step, j;
363 	int ret, serrno;
364 
365 	if (nfds == 0 || fds == NULL) {
366 		errno = EINVAL;
367 		return (-1);
368 	}
369 
370 	ret = i = step = 0;
371 	while (i < nfds) {
372 		if (PKG_MAX_SIZE < nfds - i)
373 			step = PKG_MAX_SIZE;
374 		else
375 			step = nfds - i;
376 		ret = fd_package_recv(sock, fds + i, step);
377 		if (ret != 0) {
378 			/* Close all received descriptors. */
379 			serrno = errno;
380 			for (j = 0; j < i; j++)
381 				close(fds[j]);
382 			errno = serrno;
383 			break;
384 		}
385 		i += step;
386 	}
387 
388 	return (ret);
389 }
390 
391 int
392 fd_send(int sock, const int *fds, size_t nfds)
393 {
394 	unsigned int i, step;
395 	int ret;
396 
397 	if (nfds == 0 || fds == NULL) {
398 		errno = EINVAL;
399 		return (-1);
400 	}
401 
402 	ret = i = step = 0;
403 	while (i < nfds) {
404 		if (PKG_MAX_SIZE < nfds - i)
405 			step = PKG_MAX_SIZE;
406 		else
407 			step = nfds - i;
408 		ret = fd_package_send(sock, fds + i, step);
409 		if (ret != 0)
410 			break;
411 		i += step;
412 	}
413 
414 	return (ret);
415 }
416 
417 int
418 buf_send(int sock, void *buf, size_t size)
419 {
420 	ssize_t done;
421 	unsigned char *ptr;
422 
423 	PJDLOG_ASSERT(sock >= 0);
424 	PJDLOG_ASSERT(size > 0);
425 	PJDLOG_ASSERT(buf != NULL);
426 
427 	ptr = buf;
428 	do {
429 		fd_wait(sock, false);
430 		done = send(sock, ptr, size, 0);
431 		if (done == -1) {
432 			if (errno == EINTR)
433 				continue;
434 			return (-1);
435 		} else if (done == 0) {
436 			errno = ENOTCONN;
437 			return (-1);
438 		}
439 		size -= done;
440 		ptr += done;
441 	} while (size > 0);
442 
443 	return (0);
444 }
445 
446 int
447 buf_recv(int sock, void *buf, size_t size)
448 {
449 	ssize_t done;
450 	unsigned char *ptr;
451 
452 	PJDLOG_ASSERT(sock >= 0);
453 	PJDLOG_ASSERT(buf != NULL);
454 
455 	ptr = buf;
456 	while (size > 0) {
457 		fd_wait(sock, true);
458 		done = recv(sock, ptr, size, 0);
459 		if (done == -1) {
460 			if (errno == EINTR)
461 				continue;
462 			return (-1);
463 		} else if (done == 0) {
464 			errno = ENOTCONN;
465 			return (-1);
466 		}
467 		size -= done;
468 		ptr += done;
469 	}
470 
471 	return (0);
472 }
473