xref: /freebsd/lib/libnv/msgio.c (revision fd253945ac76a54ff9c11cf02f5458561f711866)
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3  *
4  * Copyright (c) 2013 The FreeBSD Foundation
5  * Copyright (c) 2013 Mariusz Zaborski <oshogbo@FreeBSD.org>
6  * All rights reserved.
7  *
8  * This software was developed by Pawel Jakub Dawidek under sponsorship from
9  * the FreeBSD Foundation.
10  *
11  * Redistribution and use in source and binary forms, with or without
12  * modification, are permitted provided that the following conditions
13  * are met:
14  * 1. Redistributions of source code must retain the above copyright
15  *    notice, this list of conditions and the following disclaimer.
16  * 2. Redistributions in binary form must reproduce the above copyright
17  *    notice, this list of conditions and the following disclaimer in the
18  *    documentation and/or other materials provided with the distribution.
19  *
20  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
21  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
24  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
26  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
27  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
28  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
29  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
30  * SUCH DAMAGE.
31  */
32 
33 #include <sys/cdefs.h>
34 __FBSDID("$FreeBSD$");
35 
#include <sys/param.h>
#include <sys/socket.h>

#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
46 
47 #ifdef HAVE_PJDLOG
48 #include <pjdlog.h>
49 #endif
50 
51 #include "common_impl.h"
52 #include "msgio.h"
53 
#if !defined(HAVE_PJDLOG)
/*
 * Without pjdlog, map its assertion helpers onto assert(3) and abort(3).
 * Any message arguments given to PJDLOG_RASSERT() are discarded, and the
 * assertions disappear entirely when NDEBUG is defined.
 */
#include <assert.h>
#define	PJDLOG_ASSERT(...)		assert(__VA_ARGS__)
#define	PJDLOG_RASSERT(expr, ...)	assert(expr)
#define	PJDLOG_ABORT(...)		abort()
#endif	/* !HAVE_PJDLOG */

/*
 * Upper bound on the number of descriptors moved by one sendmsg(2) or
 * recvmsg(2) "package".  Each descriptor travels in its own SCM_RIGHTS
 * control message of CMSG_SPACE(sizeof(int)) bytes, so a full package's
 * control buffer stays below MCLBYTES (presumably so that it fits in a
 * single kernel mbuf cluster).
 */
#define	PKG_MAX_SIZE	((MCLBYTES / CMSG_SPACE(sizeof(int))) - 1)
62 
/*
 * Turn cmsg into a SOL_SOCKET/SCM_RIGHTS control message carrying the
 * single descriptor fd.  cmsg must have room for CMSG_LEN(sizeof(int))
 * bytes.  Always returns 0.
 */
static int
msghdr_add_fd(struct cmsghdr *cmsg, int fd)
{
	unsigned char *payload;

	PJDLOG_ASSERT(fd >= 0);

	cmsg->cmsg_len = CMSG_LEN(sizeof(fd));
	cmsg->cmsg_type = SCM_RIGHTS;
	cmsg->cmsg_level = SOL_SOCKET;

	payload = CMSG_DATA(cmsg);
	memcpy(payload, &fd, sizeof(fd));

	return (0);
}
76 
/*
 * Extract the descriptor carried by a single-descriptor SCM_RIGHTS
 * control message.
 *
 * Returns the descriptor, or -1 with errno set to EINVAL when cmsg is
 * NULL, is not SOL_SOCKET/SCM_RIGHTS, or does not hold exactly one int.
 * On platforms lacking MSG_CMSG_CLOEXEC the close-on-exec flag is set
 * here, after the fact (i.e. not atomically with reception).
 */
static int
msghdr_get_fd(struct cmsghdr *cmsg)
{
	int received;
	bool valid;

	valid = cmsg != NULL && cmsg->cmsg_level == SOL_SOCKET &&
	    cmsg->cmsg_type == SCM_RIGHTS &&
	    cmsg->cmsg_len == CMSG_LEN(sizeof(received));
	if (!valid) {
		errno = EINVAL;
		return (-1);
	}

	memcpy(&received, CMSG_DATA(cmsg), sizeof(received));
#ifndef MSG_CMSG_CLOEXEC
	/* No atomic MSG_CMSG_CLOEXEC: still mark it close-on-exec. */
	(void)fcntl(received, F_SETFD, FD_CLOEXEC);
#endif

	return (received);
}
101 
/*
 * Block until fd is readable (doread == true) or writable (doread == false).
 *
 * poll(2) is used rather than select(2): FD_SET() on a descriptor that is
 * >= FD_SETSIZE writes past the end of the fd_set, and processes holding
 * many descriptors can legitimately reach such values.  Errors, including
 * EINTR, are deliberately ignored; every caller performs the I/O next and
 * handles its failure (and EINTR retries) itself.
 */
static void
fd_wait(int fd, bool doread)
{
	struct pollfd pfd;

	PJDLOG_ASSERT(fd >= 0);

	pfd.fd = fd;
	pfd.events = doread ? POLLIN : POLLOUT;
	pfd.revents = 0;
	/* A timeout of -1 waits indefinitely, like select() with NULL. */
	(void)poll(&pfd, 1, -1);
}
114 
/*
 * Receive one message from sock into msg.  The function waits for sock to
 * become readable and retries when recvmsg(2) is interrupted (EINTR).
 * Where MSG_CMSG_CLOEXEC exists, received descriptors get close-on-exec
 * set atomically.
 *
 * Returns 0 on success, including when recvmsg(2) returns 0.  Returns -1
 * with errno from recvmsg(2) on any other failure.
 */
static int
msg_recv(int sock, struct msghdr *msg)
{
#ifdef MSG_CMSG_CLOEXEC
	const int rflags = MSG_CMSG_CLOEXEC;
#else
	const int rflags = 0;
#endif

	PJDLOG_ASSERT(sock >= 0);

	do {
		fd_wait(sock, true);
		if (recvmsg(sock, msg, rflags) != -1)
			return (0);
	} while (errno == EINTR);

	return (-1);
}
140 
/*
 * Send msg over sock.  The function waits for sock to become writable and
 * retries when sendmsg(2) is interrupted (EINTR).
 *
 * Returns 0 on success or -1 with errno from sendmsg(2).
 */
static int
msg_send(int sock, const struct msghdr *msg)
{

	PJDLOG_ASSERT(sock >= 0);

	do {
		fd_wait(sock, false);
		if (sendmsg(sock, msg, 0) != -1)
			return (0);
	} while (errno == EINTR);

	return (-1);
}
159 
160 int
161 cred_send(int sock)
162 {
163 	unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
164 	struct msghdr msg;
165 	struct cmsghdr *cmsg;
166 	struct iovec iov;
167 	uint8_t dummy;
168 
169 	bzero(credbuf, sizeof(credbuf));
170 	bzero(&msg, sizeof(msg));
171 	bzero(&iov, sizeof(iov));
172 
173 	/*
174 	 * XXX: We send one byte along with the control message, because
175 	 *      setting msg_iov to NULL only works if this is the first
176 	 *      packet send over the socket. Once we send some data we
177 	 *      won't be able to send credentials anymore. This is most
178 	 *      likely a kernel bug.
179 	 */
180 	dummy = 0;
181 	iov.iov_base = &dummy;
182 	iov.iov_len = sizeof(dummy);
183 
184 	msg.msg_iov = &iov;
185 	msg.msg_iovlen = 1;
186 	msg.msg_control = credbuf;
187 	msg.msg_controllen = sizeof(credbuf);
188 
189 	cmsg = CMSG_FIRSTHDR(&msg);
190 	cmsg->cmsg_len = CMSG_LEN(sizeof(struct cmsgcred));
191 	cmsg->cmsg_level = SOL_SOCKET;
192 	cmsg->cmsg_type = SCM_CREDS;
193 
194 	if (msg_send(sock, &msg) == -1)
195 		return (-1);
196 
197 	return (0);
198 }
199 
200 int
201 cred_recv(int sock, struct cmsgcred *cred)
202 {
203 	unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
204 	struct msghdr msg;
205 	struct cmsghdr *cmsg;
206 	struct iovec iov;
207 	uint8_t dummy;
208 
209 	bzero(credbuf, sizeof(credbuf));
210 	bzero(&msg, sizeof(msg));
211 	bzero(&iov, sizeof(iov));
212 
213 	iov.iov_base = &dummy;
214 	iov.iov_len = sizeof(dummy);
215 
216 	msg.msg_iov = &iov;
217 	msg.msg_iovlen = 1;
218 	msg.msg_control = credbuf;
219 	msg.msg_controllen = sizeof(credbuf);
220 
221 	if (msg_recv(sock, &msg) == -1)
222 		return (-1);
223 
224 	cmsg = CMSG_FIRSTHDR(&msg);
225 	if (cmsg == NULL ||
226 	    cmsg->cmsg_len != CMSG_LEN(sizeof(struct cmsgcred)) ||
227 	    cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDS) {
228 		errno = EINVAL;
229 		return (-1);
230 	}
231 	bcopy(CMSG_DATA(cmsg), cred, sizeof(*cred));
232 
233 	return (0);
234 }
235 
/*
 * Send the nfds descriptors in fds over sock with a single sendmsg(2).
 * Each descriptor goes in its own SCM_RIGHTS control message, and one
 * dummy data byte accompanies them (see cred_send()).
 *
 * Returns 0 on success.  Returns -1 on failure; errno from the failing
 * call is preserved across the cleanup.
 */
static int
fd_package_send(int sock, const int *fds, size_t nfds)
{
	struct msghdr hdr;
	struct iovec vec;
	struct cmsghdr *cm;
	unsigned int n;
	int rv, saved_errno;
	uint8_t pad;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(fds != NULL);
	PJDLOG_ASSERT(nfds > 0);

	memset(&hdr, 0, sizeof(hdr));

	/* XXX: The dummy byte is needed; see cred_send() for the reason. */
	pad = 0;
	vec.iov_base = &pad;
	vec.iov_len = sizeof(pad);

	hdr.msg_iov = &vec;
	hdr.msg_iovlen = 1;
	hdr.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	hdr.msg_control = calloc(1, hdr.msg_controllen);
	if (hdr.msg_control == NULL)
		return (-1);

	rv = -1;

	n = 0;
	cm = CMSG_FIRSTHDR(&hdr);
	while (n < nfds && cm != NULL) {
		if (msghdr_add_fd(cm, fds[n]) == -1)
			goto out;
		n++;
		cm = CMSG_NXTHDR(&hdr, cm);
	}

	if (msg_send(sock, &hdr) == 0)
		rv = 0;
out:
	saved_errno = errno;
	free(hdr.msg_control);
	errno = saved_errno;
	return (rv);
}
284 
/*
 * Close every descriptor carried by a SOL_SOCKET/SCM_RIGHTS control
 * message in msg, however many descriptors each message holds.
 *
 * msghdr_get_fd() accepts only single-descriptor messages, so using it for
 * cleanup would leak all descriptors of a message that carries more than
 * one.  The walk over each payload is bounded by both cmsg_len and the
 * received control length, so a truncated message cannot make us read
 * past the control buffer.
 */
static void
msghdr_close_fds(struct msghdr *msg)
{
	struct cmsghdr *cmsg;
	unsigned char *base, *data;
	size_t off, len, pos;
	int fd;

	base = msg->msg_control;
	for (cmsg = CMSG_FIRSTHDR(msg); cmsg != NULL;
	    cmsg = CMSG_NXTHDR(msg, cmsg)) {
		if (cmsg->cmsg_level != SOL_SOCKET ||
		    cmsg->cmsg_type != SCM_RIGHTS ||
		    cmsg->cmsg_len < CMSG_LEN(0))
			continue;
		data = CMSG_DATA(cmsg);
		off = (size_t)(data - base);
		if (off >= (size_t)msg->msg_controllen)
			continue;
		len = (size_t)cmsg->cmsg_len - CMSG_LEN(0);
		if (len > (size_t)msg->msg_controllen - off)
			len = (size_t)msg->msg_controllen - off;
		for (pos = 0; pos + sizeof(fd) <= len; pos += sizeof(fd)) {
			bcopy(data + pos, &fd, sizeof(fd));
			if (fd >= 0)
				(void)close(fd);
		}
	}
}

/*
 * Receive exactly nfds descriptors from sock into fds with one
 * recvmsg(2).  The sender must use one SCM_RIGHTS control message per
 * descriptor, as fd_package_send() does.
 *
 * Returns 0 on success.  On failure returns -1 with errno set and closes
 * every descriptor that arrived with the message.  errno is EINVAL when:
 *  - the message has too few or too many control messages;
 *  - a control message is not a single-descriptor SCM_RIGHTS message;
 *  - the control data was truncated (MSG_CTRUNC).
 * Entries of fds may have been written even on failure.
 */
static int
fd_package_recv(int sock, int *fds, size_t nfds)
{
	struct msghdr msg;
	struct cmsghdr *cmsg;
	size_t i;
	int serrno, ret;
	struct iovec iov;
	uint8_t dummy;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(nfds > 0);
	PJDLOG_ASSERT(fds != NULL);

	bzero(&msg, sizeof(msg));
	bzero(&iov, sizeof(iov));

	/*
	 * XXX: Look into cred_send function for more details.
	 */
	iov.iov_base = &dummy;
	iov.iov_len = sizeof(dummy);

	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;
	msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	msg.msg_control = calloc(1, msg.msg_controllen);
	if (msg.msg_control == NULL)
		return (-1);

	ret = -1;

	if (msg_recv(sock, &msg) == -1)
		goto end;

	for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL;
	    i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) {
		fds[i] = msghdr_get_fd(cmsg);
		if (fds[i] < 0)
			break;
	}

	/*
	 * Reject leftover or missing control messages, a malformed one, or
	 * truncated control data.  In every case close all received
	 * descriptors, including those in other control messages
	 * (e.g. SCM_CREDS in between) and multi-descriptor SCM_RIGHTS
	 * messages.
	 */
	if (cmsg != NULL || i < nfds || (msg.msg_flags & MSG_CTRUNC) != 0) {
		msghdr_close_fds(&msg);
		errno = EINVAL;
		goto end;
	}

	ret = 0;
end:
	serrno = errno;
	free(msg.msg_control);
	errno = serrno;
	return (ret);
}
351 
352 int
353 fd_recv(int sock, int *fds, size_t nfds)
354 {
355 	unsigned int i, step, j;
356 	int ret, serrno;
357 
358 	if (nfds == 0 || fds == NULL) {
359 		errno = EINVAL;
360 		return (-1);
361 	}
362 
363 	ret = i = step = 0;
364 	while (i < nfds) {
365 		if (PKG_MAX_SIZE < nfds - i)
366 			step = PKG_MAX_SIZE;
367 		else
368 			step = nfds - i;
369 		ret = fd_package_recv(sock, fds + i, step);
370 		if (ret != 0) {
371 			/* Close all received descriptors. */
372 			serrno = errno;
373 			for (j = 0; j < i; j++)
374 				close(fds[j]);
375 			errno = serrno;
376 			break;
377 		}
378 		i += step;
379 	}
380 
381 	return (ret);
382 }
383 
384 int
385 fd_send(int sock, const int *fds, size_t nfds)
386 {
387 	unsigned int i, step;
388 	int ret;
389 
390 	if (nfds == 0 || fds == NULL) {
391 		errno = EINVAL;
392 		return (-1);
393 	}
394 
395 	ret = i = step = 0;
396 	while (i < nfds) {
397 		if (PKG_MAX_SIZE < nfds - i)
398 			step = PKG_MAX_SIZE;
399 		else
400 			step = nfds - i;
401 		ret = fd_package_send(sock, fds + i, step);
402 		if (ret != 0)
403 			break;
404 		i += step;
405 	}
406 
407 	return (ret);
408 }
409 
410 int
411 buf_send(int sock, void *buf, size_t size)
412 {
413 	ssize_t done;
414 	unsigned char *ptr;
415 
416 	PJDLOG_ASSERT(sock >= 0);
417 	PJDLOG_ASSERT(size > 0);
418 	PJDLOG_ASSERT(buf != NULL);
419 
420 	ptr = buf;
421 	do {
422 		fd_wait(sock, false);
423 		done = send(sock, ptr, size, 0);
424 		if (done == -1) {
425 			if (errno == EINTR)
426 				continue;
427 			return (-1);
428 		} else if (done == 0) {
429 			errno = ENOTCONN;
430 			return (-1);
431 		}
432 		size -= done;
433 		ptr += done;
434 	} while (size > 0);
435 
436 	return (0);
437 }
438 
439 int
440 buf_recv(int sock, void *buf, size_t size)
441 {
442 	ssize_t done;
443 	unsigned char *ptr;
444 
445 	PJDLOG_ASSERT(sock >= 0);
446 	PJDLOG_ASSERT(buf != NULL);
447 
448 	ptr = buf;
449 	while (size > 0) {
450 		fd_wait(sock, true);
451 		done = recv(sock, ptr, size, 0);
452 		if (done == -1) {
453 			if (errno == EINTR)
454 				continue;
455 			return (-1);
456 		} else if (done == 0) {
457 			errno = ENOTCONN;
458 			return (-1);
459 		}
460 		size -= done;
461 		ptr += done;
462 	}
463 
464 	return (0);
465 }
466