xref: /freebsd/lib/libnv/msgio.c (revision 6829dae12bb055451fa467da4589c43bd03b1e64)
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3  *
4  * Copyright (c) 2013 The FreeBSD Foundation
5  * Copyright (c) 2013 Mariusz Zaborski <oshogbo@FreeBSD.org>
6  * All rights reserved.
7  *
8  * This software was developed by Pawel Jakub Dawidek under sponsorship from
9  * the FreeBSD Foundation.
10  *
11  * Redistribution and use in source and binary forms, with or without
12  * modification, are permitted provided that the following conditions
13  * are met:
14  * 1. Redistributions of source code must retain the above copyright
15  *    notice, this list of conditions and the following disclaimer.
16  * 2. Redistributions in binary form must reproduce the above copyright
17  *    notice, this list of conditions and the following disclaimer in the
18  *    documentation and/or other materials provided with the distribution.
19  *
20  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
21  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
24  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
26  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
27  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
28  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
29  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
30  * SUCH DAMAGE.
31  */
32 
33 #include <sys/cdefs.h>
34 __FBSDID("$FreeBSD$");
35 
#include <sys/param.h>
#include <sys/socket.h>

#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
46 
47 #ifdef HAVE_PJDLOG
48 #include <pjdlog.h>
49 #endif
50 
51 #include "common_impl.h"
52 #include "msgio.h"
53 
54 #ifndef	HAVE_PJDLOG
55 #include <assert.h>
56 #define	PJDLOG_ASSERT(...)		assert(__VA_ARGS__)
57 #define	PJDLOG_RASSERT(expr, ...)	assert(expr)
58 #define	PJDLOG_ABORT(...)		abort()
59 #endif
60 
/*
 * Largest number of descriptors handed to a single fd_package_send() or
 * fd_package_recv() call: one fewer than the number of int-sized control
 * messages (CMSG_SPACE(sizeof(int)) bytes each) that fit in MCLBYTES.
 */
#define	PKG_MAX_SIZE	((MCLBYTES / CMSG_SPACE(sizeof(int))) - 1)
62 
/*
 * Turn the control message at cmsg into a SOL_SOCKET/SCM_RIGHTS message
 * whose payload is the single descriptor fd.  The payload is written through
 * CMSG_DATA(), so cmsg must have room for an int behind its header.
 *
 * Always returns 0.
 */
static int
msghdr_add_fd(struct cmsghdr *cmsg, int fd)
{

	PJDLOG_ASSERT(fd >= 0);

	/* Copy byte-wise: the payload area is not accessed as an int. */
	memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));
	cmsg->cmsg_len = CMSG_LEN(sizeof(fd));
	cmsg->cmsg_type = SCM_RIGHTS;
	cmsg->cmsg_level = SOL_SOCKET;

	return (0);
}
76 
/*
 * Extract the descriptor carried by a control message.
 *
 * The message is accepted only if it is a SOL_SOCKET/SCM_RIGHTS message
 * whose length is exactly that of a single int payload.
 *
 * Returns the descriptor on success; on a NULL or non-matching message
 * returns -1 with errno set to EINVAL.
 */
static int
msghdr_get_fd(struct cmsghdr *cmsg)
{
	int fd;

	if (cmsg == NULL)
		goto invalid;
	if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS)
		goto invalid;
	if (cmsg->cmsg_len != CMSG_LEN(sizeof(fd)))
		goto invalid;

	memcpy(&fd, CMSG_DATA(cmsg), sizeof(fd));
#ifndef MSG_CMSG_CLOEXEC
	/*
	 * Without MSG_CMSG_CLOEXEC the close-on-exec flag cannot be set
	 * atomically at receive time; set it here anyway for consistency.
	 * This is best effort, so the result is deliberately ignored.
	 */
	(void) fcntl(fd, F_SETFD, FD_CLOEXEC);
#endif

	return (fd);

invalid:
	errno = EINVAL;
	return (-1);
}
101 
/*
 * Block until fd is ready for reading (doread == true) or for writing
 * (doread == false).
 *
 * poll(2) is used instead of select(2) because FD_SET() on a descriptor
 * greater than or equal to FD_SETSIZE writes outside the fd_set; poll(2)
 * has no limit on the descriptor value.
 *
 * The result of poll(2) is deliberately ignored: callers follow up with the
 * actual I/O call, which reports any error (including EINTR) itself.
 */
static void
fd_wait(int fd, bool doread)
{
	struct pollfd pfd;

	PJDLOG_ASSERT(fd >= 0);

	pfd.fd = fd;
	pfd.events = doread ? POLLIN : POLLOUT;
	pfd.revents = 0;
	/* A timeout of -1 waits indefinitely. */
	(void)poll(&pfd, 1, -1);
}
114 
/*
 * Receive one message on sock into msg, waiting for the socket to become
 * readable first and restarting whenever recvmsg(2) fails with EINTR.
 * MSG_CMSG_CLOEXEC is passed where the platform defines it.
 *
 * Returns 0 on success, or -1 with errno set by recvmsg(2).
 */
static int
msg_recv(int sock, struct msghdr *msg)
{
	ssize_t received;
	int flags = 0;

	PJDLOG_ASSERT(sock >= 0);

#ifdef MSG_CMSG_CLOEXEC
	flags |= MSG_CMSG_CLOEXEC;
#endif

	do {
		fd_wait(sock, true);
		received = recvmsg(sock, msg, flags);
	} while (received == -1 && errno == EINTR);

	return (received == -1 ? -1 : 0);
}
140 
/*
 * Send one message on sock, waiting for the socket to become writable first
 * and restarting whenever sendmsg(2) fails with EINTR.
 *
 * Returns 0 on success, or -1 with errno set by sendmsg(2).
 */
static int
msg_send(int sock, const struct msghdr *msg)
{
	ssize_t sent;

	PJDLOG_ASSERT(sock >= 0);

	do {
		fd_wait(sock, false);
		sent = sendmsg(sock, msg, 0);
	} while (sent == -1 && errno == EINTR);

	return (sent == -1 ? -1 : 0);
}
159 
/*
 * macOS and Linux do not define struct cmsgcred, but libnv must still be
 * bootstrapped when building on non-FreeBSD systems. cred_send() and
 * cred_recv() are not used during bootstrap, so both functions are simply
 * omitted there.
 */
165 #ifndef __FreeBSD__
166 #warning "cred_send() not supported on non-FreeBSD systems"
167 #else
168 int
169 cred_send(int sock)
170 {
171 	unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
172 	struct msghdr msg;
173 	struct cmsghdr *cmsg;
174 	struct iovec iov;
175 	uint8_t dummy;
176 
177 	bzero(credbuf, sizeof(credbuf));
178 	bzero(&msg, sizeof(msg));
179 	bzero(&iov, sizeof(iov));
180 
181 	/*
182 	 * XXX: We send one byte along with the control message, because
183 	 *      setting msg_iov to NULL only works if this is the first
184 	 *      packet send over the socket. Once we send some data we
185 	 *      won't be able to send credentials anymore. This is most
186 	 *      likely a kernel bug.
187 	 */
188 	dummy = 0;
189 	iov.iov_base = &dummy;
190 	iov.iov_len = sizeof(dummy);
191 
192 	msg.msg_iov = &iov;
193 	msg.msg_iovlen = 1;
194 	msg.msg_control = credbuf;
195 	msg.msg_controllen = sizeof(credbuf);
196 
197 	cmsg = CMSG_FIRSTHDR(&msg);
198 	cmsg->cmsg_len = CMSG_LEN(sizeof(struct cmsgcred));
199 	cmsg->cmsg_level = SOL_SOCKET;
200 	cmsg->cmsg_type = SCM_CREDS;
201 
202 	if (msg_send(sock, &msg) == -1)
203 		return (-1);
204 
205 	return (0);
206 }
207 
208 int
209 cred_recv(int sock, struct cmsgcred *cred)
210 {
211 	unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
212 	struct msghdr msg;
213 	struct cmsghdr *cmsg;
214 	struct iovec iov;
215 	uint8_t dummy;
216 
217 	bzero(credbuf, sizeof(credbuf));
218 	bzero(&msg, sizeof(msg));
219 	bzero(&iov, sizeof(iov));
220 
221 	iov.iov_base = &dummy;
222 	iov.iov_len = sizeof(dummy);
223 
224 	msg.msg_iov = &iov;
225 	msg.msg_iovlen = 1;
226 	msg.msg_control = credbuf;
227 	msg.msg_controllen = sizeof(credbuf);
228 
229 	if (msg_recv(sock, &msg) == -1)
230 		return (-1);
231 
232 	cmsg = CMSG_FIRSTHDR(&msg);
233 	if (cmsg == NULL ||
234 	    cmsg->cmsg_len != CMSG_LEN(sizeof(struct cmsgcred)) ||
235 	    cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDS) {
236 		errno = EINVAL;
237 		return (-1);
238 	}
239 	bcopy(CMSG_DATA(cmsg), cred, sizeof(*cred));
240 
241 	return (0);
242 }
243 #endif
244 
/*
 * Send nfds descriptors from fds over sock in a single message.
 *
 * Each descriptor is placed in its own SCM_RIGHTS control message, and one
 * dummy data byte accompanies the control data (see the comment in
 * cred_send() for the reason).
 *
 * Returns 0 on success, or -1 with errno set on allocation or send failure.
 * errno is preserved across the release of the control buffer.
 */
static int
fd_package_send(int sock, const int *fds, size_t nfds)
{
	struct iovec iov;
	struct msghdr msg;
	struct cmsghdr *cmsg;
	uint8_t dummy = 0;
	unsigned int idx;
	int result, saved_errno;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(fds != NULL);
	PJDLOG_ASSERT(nfds > 0);

	memset(&msg, 0, sizeof(msg));

	/*
	 * XXX: Look into cred_send function for more details.
	 */
	iov.iov_base = &dummy;
	iov.iov_len = sizeof(dummy);
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;

	/* One zero-filled control message slot per descriptor. */
	msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	msg.msg_control = calloc(1, msg.msg_controllen);
	if (msg.msg_control == NULL)
		return (-1);

	result = -1;

	idx = 0;
	cmsg = CMSG_FIRSTHDR(&msg);
	while (idx < nfds && cmsg != NULL) {
		if (msghdr_add_fd(cmsg, fds[idx]) == -1)
			goto end;
		idx++;
		cmsg = CMSG_NXTHDR(&msg, cmsg);
	}

	if (msg_send(sock, &msg) != -1)
		result = 0;
end:
	/* free() may modify errno; keep the value that describes the failure. */
	saved_errno = errno;
	free(msg.msg_control);
	errno = saved_errno;
	return (result);
}
293 
/*
 * Receive one package of exactly nfds descriptors from sock into fds.
 *
 * One dummy data byte is read together with the control data (see the
 * comment in cred_send() for the reason).  Every control message received
 * must be a SOL_SOCKET/SCM_RIGHTS message carrying a single descriptor, and
 * there must be exactly nfds of them.  Otherwise every descriptor found in
 * the control data is closed and the call fails with EINVAL; the contents
 * of fds are then not meaningful.
 *
 * Returns 0 on success, or -1 with errno set on failure.  errno is preserved
 * across the release of the control buffer.
 */
static int
fd_package_recv(int sock, int *fds, size_t nfds)
{
	struct msghdr msg;
	struct cmsghdr *cmsg;
	struct iovec iov;
	size_t i, j, count;
	int serrno, ret, fd;
	uint8_t dummy;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(nfds > 0);
	PJDLOG_ASSERT(fds != NULL);

	memset(&msg, 0, sizeof(msg));
	memset(&iov, 0, sizeof(iov));

	/*
	 * XXX: Look into cred_send function for more details.
	 */
	iov.iov_base = &dummy;
	iov.iov_len = sizeof(dummy);

	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;
	msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	msg.msg_control = calloc(1, msg.msg_controllen);
	if (msg.msg_control == NULL)
		return (-1);

	ret = -1;

	if (msg_recv(sock, &msg) == -1)
		goto end;

	for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL;
	    i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) {
		fds[i] = msghdr_get_fd(cmsg);
		if (fds[i] < 0)
			break;
	}

	if (cmsg != NULL || i < nfds) {
		/*
		 * We need to close all received descriptors, even if we have
		 * a different control message (e.g. SCM_CREDS) in between.
		 *
		 * Walk the payload of every SCM_RIGHTS message directly
		 * instead of going through msghdr_get_fd(): that helper only
		 * accepts messages carrying exactly one descriptor, so the
		 * descriptors of a message of any other length would leak.
		 */
		for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL;
		    cmsg = CMSG_NXTHDR(&msg, cmsg)) {
			if (cmsg->cmsg_level != SOL_SOCKET ||
			    cmsg->cmsg_type != SCM_RIGHTS ||
			    cmsg->cmsg_len < CMSG_LEN(0))
				continue;
			count = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(fd);
			for (j = 0; j < count; j++) {
				memcpy(&fd, (unsigned char *)CMSG_DATA(cmsg) +
				    j * sizeof(fd), sizeof(fd));
				if (fd >= 0)
					close(fd);
			}
		}
		errno = EINVAL;
		goto end;
	}

	ret = 0;
end:
	/* free() may modify errno; keep the value that describes the failure. */
	serrno = errno;
	free(msg.msg_control);
	errno = serrno;
	return (ret);
}
360 
361 int
362 fd_recv(int sock, int *fds, size_t nfds)
363 {
364 	unsigned int i, step, j;
365 	int ret, serrno;
366 
367 	if (nfds == 0 || fds == NULL) {
368 		errno = EINVAL;
369 		return (-1);
370 	}
371 
372 	ret = i = step = 0;
373 	while (i < nfds) {
374 		if (PKG_MAX_SIZE < nfds - i)
375 			step = PKG_MAX_SIZE;
376 		else
377 			step = nfds - i;
378 		ret = fd_package_recv(sock, fds + i, step);
379 		if (ret != 0) {
380 			/* Close all received descriptors. */
381 			serrno = errno;
382 			for (j = 0; j < i; j++)
383 				close(fds[j]);
384 			errno = serrno;
385 			break;
386 		}
387 		i += step;
388 	}
389 
390 	return (ret);
391 }
392 
393 int
394 fd_send(int sock, const int *fds, size_t nfds)
395 {
396 	unsigned int i, step;
397 	int ret;
398 
399 	if (nfds == 0 || fds == NULL) {
400 		errno = EINVAL;
401 		return (-1);
402 	}
403 
404 	ret = i = step = 0;
405 	while (i < nfds) {
406 		if (PKG_MAX_SIZE < nfds - i)
407 			step = PKG_MAX_SIZE;
408 		else
409 			step = nfds - i;
410 		ret = fd_package_send(sock, fds + i, step);
411 		if (ret != 0)
412 			break;
413 		i += step;
414 	}
415 
416 	return (ret);
417 }
418 
/*
 * Send exactly size bytes from buf over sock, looping over short writes and
 * restarting when send(2) fails with EINTR.
 *
 * Returns 0 once everything has been sent.  Returns -1 with errno set by
 * send(2) on failure, or with errno set to ENOTCONN if send(2) reports that
 * zero bytes were written.
 */
int
buf_send(int sock, void *buf, size_t size)
{
	unsigned char *cursor;
	ssize_t nsent;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(size > 0);
	PJDLOG_ASSERT(buf != NULL);

	cursor = buf;
	do {
		fd_wait(sock, false);
		nsent = send(sock, cursor, size, 0);
		if (nsent == -1) {
			/* An interrupted call is simply retried. */
			if (errno != EINTR)
				return (-1);
		} else if (nsent == 0) {
			errno = ENOTCONN;
			return (-1);
		} else {
			cursor += nsent;
			size -= nsent;
		}
	} while (size > 0);

	return (0);
}
447 
/*
 * Receive exactly size bytes from sock into buf, looping over short reads
 * and restarting when recv(2) fails with EINTR.  A size of zero succeeds
 * immediately without touching the socket.
 *
 * Returns 0 once everything has been received.  Returns -1 with errno set
 * by recv(2) on failure, or with errno set to ENOTCONN if recv(2) returns
 * zero before all bytes have arrived.
 */
int
buf_recv(int sock, void *buf, size_t size)
{
	unsigned char *cursor;
	ssize_t nread;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(buf != NULL);

	for (cursor = buf; size > 0; ) {
		fd_wait(sock, true);
		nread = recv(sock, cursor, size, 0);
		if (nread == -1) {
			/* An interrupted call is simply retried. */
			if (errno != EINTR)
				return (-1);
		} else if (nread == 0) {
			errno = ENOTCONN;
			return (-1);
		} else {
			cursor += nread;
			size -= nread;
		}
	}

	return (0);
}
475