xref: /freebsd/lib/libnv/msgio.c (revision 52f45d8acee95199159b65a33c94142492c38e41)
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3  *
4  * Copyright (c) 2013 The FreeBSD Foundation
5  * Copyright (c) 2013 Mariusz Zaborski <oshogbo@FreeBSD.org>
6  * All rights reserved.
7  *
8  * This software was developed by Pawel Jakub Dawidek under sponsorship from
9  * the FreeBSD Foundation.
10  *
11  * Redistribution and use in source and binary forms, with or without
12  * modification, are permitted provided that the following conditions
13  * are met:
14  * 1. Redistributions of source code must retain the above copyright
15  *    notice, this list of conditions and the following disclaimer.
16  * 2. Redistributions in binary form must reproduce the above copyright
17  *    notice, this list of conditions and the following disclaimer in the
18  *    documentation and/or other materials provided with the distribution.
19  *
20  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
21  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
24  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
26  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
27  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
28  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
29  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
30  * SUCH DAMAGE.
31  */
32 
33 #include <sys/cdefs.h>
34 __FBSDID("$FreeBSD$");
35 
#include <sys/param.h>
#include <sys/socket.h>
#include <sys/select.h>

#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
47 
48 #ifdef HAVE_PJDLOG
49 #include <pjdlog.h>
50 #endif
51 
52 #include "common_impl.h"
53 #include "msgio.h"
54 
#if !defined(HAVE_PJDLOG)
/*
 * Built without pjdlog: fall back to assert(3)/abort(3).  Any format
 * arguments passed to the "R" and abort variants are simply dropped.
 */
#include <assert.h>
#define	PJDLOG_ASSERT(...)		assert(__VA_ARGS__)
#define	PJDLOG_RASSERT(expr, ...)	assert(expr)
#define	PJDLOG_ABORT(...)		abort()
#endif	/* !HAVE_PJDLOG */
61 
#if defined(__linux__)
/*
 * Linux: the bound is arbitrary, but it has to stay below SCM_MAX_FD.
 *
 * NOTE(review): fd_send()/fd_recv() use PKG_MAX_SIZE as a number of
 * descriptors per package, yet this definition scales 63 by
 * CMSG_SPACE(sizeof(int)), i.e. it looks like a byte count.  Confirm that
 * the resulting descriptor count really is below SCM_MAX_FD on the target
 * kernel.  Both peers must agree on this value, so change it only together.
 */
#define	PKG_MAX_SIZE	((64U - 1) * CMSG_SPACE(sizeof(int)))
#else
/*
 * Cap the descriptors sent per message with a machine-independent value so
 * 32-bit binaries running on 64-bit kernels are not tripped up by emulation
 * limits.  Every descriptor travels in its own 24-byte control message:
 * a 12-byte header, 4 bytes of padding, the 4-byte descriptor and 4 more
 * bytes of padding.
 */
#define	PKG_MAX_SIZE	(MCLBYTES / 24)
#endif	/* __linux__ */
74 
/*
 * Fill the control message header at 'cmsg' so that it carries the single
 * descriptor 'fd' as SOL_SOCKET/SCM_RIGHTS.  Nothing is bounds-checked here:
 * the caller must point 'cmsg' at room for CMSG_SPACE(sizeof(int)) bytes.
 * Always returns 0.
 */
static int
msghdr_add_fd(struct cmsghdr *cmsg, int fd)
{

	PJDLOG_ASSERT(fd >= 0);

	cmsg->cmsg_len = CMSG_LEN(sizeof(fd));
	cmsg->cmsg_level = SOL_SOCKET;
	cmsg->cmsg_type = SCM_RIGHTS;
	/* Copy rather than store through a cast int pointer. */
	memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));

	return (0);
}
88 
/*
 * Block until 'fd' is ready for I/O: readable when 'doread' is true,
 * writable otherwise.  The result of the wait is discarded on purpose;
 * callers follow up with the actual I/O call, which reports any error.
 *
 * poll(2) is used instead of select(2).  FD_SET() on a descriptor numbered
 * FD_SETSIZE or higher writes past the end of the fd_set, which is undefined
 * behavior (a stack overwrite in practice).  poll(2) takes the descriptor
 * number directly and has no such limit.
 */
static void
fd_wait(int fd, bool doread)
{
	struct pollfd pfd;

	PJDLOG_ASSERT(fd >= 0);

	pfd.fd = fd;
	pfd.events = doread ? POLLIN : POLLOUT;
	pfd.revents = 0;
	/* Infinite timeout, matching the NULL timeout previously given to select(2). */
	(void)poll(&pfd, 1, -1);
}
101 
/*
 * Receive one message from 'sock' into the caller-prepared 'msg', waiting
 * for the socket to become readable first and retrying when recvmsg(2) is
 * interrupted (EINTR).  When MSG_CMSG_CLOEXEC exists, received descriptors
 * get close-on-exec set atomically.
 * Returns 0 when recvmsg(2) succeeds (a zero-byte read counts as success)
 * or -1 with errno set by recvmsg(2).
 */
static int
msg_recv(int sock, struct msghdr *msg)
{
	int rflags = 0;

	PJDLOG_ASSERT(sock >= 0);

#ifdef MSG_CMSG_CLOEXEC
	rflags |= MSG_CMSG_CLOEXEC;
#endif

	do {
		fd_wait(sock, true);
		if (recvmsg(sock, msg, rflags) != -1)
			return (0);
	} while (errno == EINTR);

	return (-1);
}
127 
/*
 * Send the fully prepared message 'msg' on 'sock', waiting for the socket to
 * become writable first and retrying when sendmsg(2) is interrupted (EINTR).
 * Returns 0 on success or -1 with errno set by sendmsg(2).
 */
static int
msg_send(int sock, const struct msghdr *msg)
{

	PJDLOG_ASSERT(sock >= 0);

	do {
		fd_wait(sock, false);
		if (sendmsg(sock, msg, 0) != -1)
			return (0);
	} while (errno == EINTR);

	return (-1);
}
146 
147 #ifdef __FreeBSD__
148 int
149 cred_send(int sock)
150 {
151 	unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
152 	struct msghdr msg;
153 	struct cmsghdr *cmsg;
154 	struct iovec iov;
155 	uint8_t dummy;
156 
157 	bzero(credbuf, sizeof(credbuf));
158 	bzero(&msg, sizeof(msg));
159 	bzero(&iov, sizeof(iov));
160 
161 	/*
162 	 * XXX: We send one byte along with the control message, because
163 	 *      setting msg_iov to NULL only works if this is the first
164 	 *      packet send over the socket. Once we send some data we
165 	 *      won't be able to send credentials anymore. This is most
166 	 *      likely a kernel bug.
167 	 */
168 	dummy = 0;
169 	iov.iov_base = &dummy;
170 	iov.iov_len = sizeof(dummy);
171 
172 	msg.msg_iov = &iov;
173 	msg.msg_iovlen = 1;
174 	msg.msg_control = credbuf;
175 	msg.msg_controllen = sizeof(credbuf);
176 
177 	cmsg = CMSG_FIRSTHDR(&msg);
178 	cmsg->cmsg_len = CMSG_LEN(sizeof(struct cmsgcred));
179 	cmsg->cmsg_level = SOL_SOCKET;
180 	cmsg->cmsg_type = SCM_CREDS;
181 
182 	if (msg_send(sock, &msg) == -1)
183 		return (-1);
184 
185 	return (0);
186 }
187 
188 int
189 cred_recv(int sock, struct cmsgcred *cred)
190 {
191 	unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
192 	struct msghdr msg;
193 	struct cmsghdr *cmsg;
194 	struct iovec iov;
195 	uint8_t dummy;
196 
197 	bzero(credbuf, sizeof(credbuf));
198 	bzero(&msg, sizeof(msg));
199 	bzero(&iov, sizeof(iov));
200 
201 	iov.iov_base = &dummy;
202 	iov.iov_len = sizeof(dummy);
203 
204 	msg.msg_iov = &iov;
205 	msg.msg_iovlen = 1;
206 	msg.msg_control = credbuf;
207 	msg.msg_controllen = sizeof(credbuf);
208 
209 	if (msg_recv(sock, &msg) == -1)
210 		return (-1);
211 
212 	cmsg = CMSG_FIRSTHDR(&msg);
213 	if (cmsg == NULL ||
214 	    cmsg->cmsg_len != CMSG_LEN(sizeof(struct cmsgcred)) ||
215 	    cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDS) {
216 		errno = EINVAL;
217 		return (-1);
218 	}
219 	bcopy(CMSG_DATA(cmsg), cred, sizeof(*cred));
220 
221 	return (0);
222 }
223 #endif
224 
/*
 * Send one package of 'nfds' descriptors from 'fds' on 'sock' with a single
 * sendmsg(2).  Every descriptor is placed in its own SCM_RIGHTS control
 * message, and a one-byte dummy payload goes along (see the XXX note in
 * cred_send()).
 * Returns 0 on success or -1 with errno set.  errno survives the release of
 * the heap-allocated control buffer.
 */
static int
fd_package_send(int sock, const int *fds, size_t nfds)
{
	struct msghdr msg;
	struct cmsghdr *cmsg;
	struct iovec iov;
	unsigned int idx;
	int result, saved_errno;
	uint8_t dummy;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(fds != NULL);
	PJDLOG_ASSERT(nfds > 0);

	memset(&msg, 0, sizeof(msg));

	dummy = 0;
	iov.iov_base = &dummy;
	iov.iov_len = sizeof(dummy);

	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;
	/* One fixed-size control message slot per descriptor. */
	msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	msg.msg_control = calloc(1, msg.msg_controllen);
	if (msg.msg_control == NULL)
		return (-1);

	result = -1;

	idx = 0;
	cmsg = CMSG_FIRSTHDR(&msg);
	while (idx < nfds && cmsg != NULL) {
		if (msghdr_add_fd(cmsg, fds[idx]) == -1)
			goto end;
		idx++;
		cmsg = CMSG_NXTHDR(&msg, cmsg);
	}

	if (msg_send(sock, &msg) != -1)
		result = 0;
end:
	saved_errno = errno;
	free(msg.msg_control);
	errno = saved_errno;
	return (result);
}
273 
/*
 * Close every descriptor carried by a SOL_SOCKET/SCM_RIGHTS control message
 * in the received 'msg'.  Used on fd_package_recv()'s error path so that no
 * descriptor delivered with the message is leaked, wherever parsing stopped.
 */
static void
fd_package_close_received(struct msghdr *msg)
{
	struct cmsghdr *cmsg;
	size_t i, n;
	int fd;

	for (cmsg = CMSG_FIRSTHDR(msg); cmsg != NULL;
	    cmsg = CMSG_NXTHDR(msg, cmsg)) {
		/* A malformed length would prevent advancing safely. */
		if (cmsg->cmsg_len < CMSG_LEN(0))
			break;
		if (cmsg->cmsg_level != SOL_SOCKET ||
		    cmsg->cmsg_type != SCM_RIGHTS)
			continue;
		n = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
		for (i = 0; i < n; i++) {
			memcpy(&fd, CMSG_DATA(cmsg) + i * sizeof(int),
			    sizeof(fd));
			if (fd >= 0)
				close(fd);
		}
	}
}

/*
 * Receive exactly 'nfds' descriptors from 'sock' into 'fds' with a single
 * recvmsg(2) call, i.e. one package as produced by fd_package_send().
 *
 * The message must contain only SOL_SOCKET/SCM_RIGHTS control messages that
 * together carry exactly 'nfds' descriptors.  If it does not, every
 * descriptor delivered in SCM_RIGHTS messages of this message is closed
 * (including ones beyond 'nfds', or ones after a foreign control message
 * such as SCM_CREDS), errno is set to EINVAL and -1 is returned.
 * Returns -1 with errno from calloc(3) or recvmsg(2) on those failures and
 * 0 on success.  errno survives the release of the control buffer.
 */
static int
fd_package_recv(int sock, int *fds, size_t nfds)
{
	struct msghdr msg;
	struct cmsghdr *cmsg;
	struct iovec iov;
	size_t i, n;
	int serrno, ret;
	uint8_t dummy;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(nfds > 0);
	PJDLOG_ASSERT(fds != NULL);

	bzero(&msg, sizeof(msg));
	bzero(&iov, sizeof(iov));

	/*
	 * XXX: Look into cred_send function for more details.
	 */
	iov.iov_base = &dummy;
	iov.iov_len = sizeof(dummy);

	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;
	msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	msg.msg_control = calloc(1, msg.msg_controllen);
	if (msg.msg_control == NULL)
		return (-1);

	ret = -1;

	if (msg_recv(sock, &msg) == -1)
		goto end;

	i = 0;
	cmsg = CMSG_FIRSTHDR(&msg);
	while (cmsg != NULL && i < nfds) {
		if (cmsg->cmsg_level != SOL_SOCKET ||
		    cmsg->cmsg_type != SCM_RIGHTS ||
		    cmsg->cmsg_len < CMSG_LEN(0))
			break;
		n = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
		/* Written as a subtraction so 'i + n' cannot overflow. */
		if (n > nfds - i)
			break;
		bcopy(CMSG_DATA(cmsg), fds + i, sizeof(int) * n);
		cmsg = CMSG_NXTHDR(&msg, cmsg);
		i += n;
	}

	if (cmsg != NULL || i < nfds) {
		/*
		 * Close everything that arrived with this message, not just the
		 * descriptors already copied into 'fds': the ones in the
		 * rejected control message and in any later SCM_RIGHTS message
		 * (e.g. after an SCM_CREDS in between) would otherwise leak.
		 */
		fd_package_close_received(&msg);
		errno = EINVAL;
		goto end;
	}

#ifndef MSG_CMSG_CLOEXEC
	/*
	 * If the MSG_CMSG_CLOEXEC flag is not available we cannot set the
	 * close-on-exec flag atomically, but we still want to set it for
	 * consistency.
	 */
	for (i = 0; i < nfds; i++) {
		(void) fcntl(fds[i], F_SETFD, FD_CLOEXEC);
	}
#endif

	ret = 0;
end:
	serrno = errno;
	free(msg.msg_control);
	errno = serrno;
	return (ret);
}
364 
365 int
366 fd_recv(int sock, int *fds, size_t nfds)
367 {
368 	unsigned int i, step, j;
369 	int ret, serrno;
370 
371 	if (nfds == 0 || fds == NULL) {
372 		errno = EINVAL;
373 		return (-1);
374 	}
375 
376 	ret = i = step = 0;
377 	while (i < nfds) {
378 		if (PKG_MAX_SIZE < nfds - i)
379 			step = PKG_MAX_SIZE;
380 		else
381 			step = nfds - i;
382 		ret = fd_package_recv(sock, fds + i, step);
383 		if (ret != 0) {
384 			/* Close all received descriptors. */
385 			serrno = errno;
386 			for (j = 0; j < i; j++)
387 				close(fds[j]);
388 			errno = serrno;
389 			break;
390 		}
391 		i += step;
392 	}
393 
394 	return (ret);
395 }
396 
397 int
398 fd_send(int sock, const int *fds, size_t nfds)
399 {
400 	unsigned int i, step;
401 	int ret;
402 
403 	if (nfds == 0 || fds == NULL) {
404 		errno = EINVAL;
405 		return (-1);
406 	}
407 
408 	ret = i = step = 0;
409 	while (i < nfds) {
410 		if (PKG_MAX_SIZE < nfds - i)
411 			step = PKG_MAX_SIZE;
412 		else
413 			step = nfds - i;
414 		ret = fd_package_send(sock, fds + i, step);
415 		if (ret != 0)
416 			break;
417 		i += step;
418 	}
419 
420 	return (ret);
421 }
422 
423 int
424 buf_send(int sock, void *buf, size_t size)
425 {
426 	ssize_t done;
427 	unsigned char *ptr;
428 
429 	PJDLOG_ASSERT(sock >= 0);
430 	PJDLOG_ASSERT(size > 0);
431 	PJDLOG_ASSERT(buf != NULL);
432 
433 	ptr = buf;
434 	do {
435 		fd_wait(sock, false);
436 		done = send(sock, ptr, size, 0);
437 		if (done == -1) {
438 			if (errno == EINTR)
439 				continue;
440 			return (-1);
441 		} else if (done == 0) {
442 			errno = ENOTCONN;
443 			return (-1);
444 		}
445 		size -= done;
446 		ptr += done;
447 	} while (size > 0);
448 
449 	return (0);
450 }
451 
452 int
453 buf_recv(int sock, void *buf, size_t size, int flags)
454 {
455 	ssize_t done;
456 	unsigned char *ptr;
457 
458 	PJDLOG_ASSERT(sock >= 0);
459 	PJDLOG_ASSERT(buf != NULL);
460 
461 	ptr = buf;
462 	while (size > 0) {
463 		fd_wait(sock, true);
464 		done = recv(sock, ptr, size, flags);
465 		if (done == -1) {
466 			if (errno == EINTR)
467 				continue;
468 			return (-1);
469 		} else if (done == 0) {
470 			errno = ENOTCONN;
471 			return (-1);
472 		}
473 		size -= done;
474 		ptr += done;
475 	}
476 
477 	return (0);
478 }
479