1 /*- 2 * SPDX-License-Identifier: BSD-2-Clause-FreeBSD 3 * 4 * Copyright (c) 2013 The FreeBSD Foundation 5 * Copyright (c) 2013 Mariusz Zaborski <oshogbo@FreeBSD.org> 6 * All rights reserved. 7 * 8 * This software was developed by Pawel Jakub Dawidek under sponsorship from 9 * the FreeBSD Foundation. 10 * 11 * Redistribution and use in source and binary forms, with or without 12 * modification, are permitted provided that the following conditions 13 * are met: 14 * 1. Redistributions of source code must retain the above copyright 15 * notice, this list of conditions and the following disclaimer. 16 * 2. Redistributions in binary form must reproduce the above copyright 17 * notice, this list of conditions and the following disclaimer in the 18 * documentation and/or other materials provided with the distribution. 19 * 20 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND 21 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 23 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE 24 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS 26 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 27 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 28 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 29 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF 30 * SUCH DAMAGE. 31 */ 32 33 #include <sys/cdefs.h> 34 __FBSDID("$FreeBSD$"); 35 36 #include <sys/param.h> 37 #include <sys/socket.h> 38 39 #include <errno.h> 40 #include <fcntl.h> 41 #include <stdbool.h> 42 #include <stdint.h> 43 #include <stdlib.h> 44 #include <string.h> 45 #include <unistd.h> 46 47 #ifdef HAVE_PJDLOG 48 #include <pjdlog.h> 49 #endif 50 51 #include "common_impl.h" 52 #include "msgio.h" 53 54 #ifndef HAVE_PJDLOG 55 #include <assert.h> 56 #define PJDLOG_ASSERT(...) assert(__VA_ARGS__) 57 #define PJDLOG_RASSERT(expr, ...) assert(expr) 58 #define PJDLOG_ABORT(...) abort() 59 #endif 60 61 #define PKG_MAX_SIZE (MCLBYTES / CMSG_SPACE(sizeof(int)) - 1) 62 63 static int 64 msghdr_add_fd(struct cmsghdr *cmsg, int fd) 65 { 66 67 PJDLOG_ASSERT(fd >= 0); 68 69 cmsg->cmsg_level = SOL_SOCKET; 70 cmsg->cmsg_type = SCM_RIGHTS; 71 cmsg->cmsg_len = CMSG_LEN(sizeof(fd)); 72 bcopy(&fd, CMSG_DATA(cmsg), sizeof(fd)); 73 74 return (0); 75 } 76 77 static int 78 msghdr_get_fd(struct cmsghdr *cmsg) 79 { 80 int fd; 81 82 if (cmsg == NULL || cmsg->cmsg_level != SOL_SOCKET || 83 cmsg->cmsg_type != SCM_RIGHTS || 84 cmsg->cmsg_len != CMSG_LEN(sizeof(fd))) { 85 errno = EINVAL; 86 return (-1); 87 } 88 89 bcopy(CMSG_DATA(cmsg), &fd, sizeof(fd)); 90 #ifndef MSG_CMSG_CLOEXEC 91 /* 92 * If the MSG_CMSG_CLOEXEC flag is not available we cannot set the 93 * close-on-exec flag atomically, but we still want to set it for 94 * consistency. 95 */ 96 (void) fcntl(fd, F_SETFD, FD_CLOEXEC); 97 #endif 98 99 return (fd); 100 } 101 102 static void 103 fd_wait(int fd, bool doread) 104 { 105 fd_set fds; 106 107 PJDLOG_ASSERT(fd >= 0); 108 109 FD_ZERO(&fds); 110 FD_SET(fd, &fds); 111 (void)select(fd + 1, doread ? &fds : NULL, doread ? NULL : &fds, 112 NULL, NULL); 113 } 114 115 static int 116 msg_recv(int sock, struct msghdr *msg) 117 { 118 int flags; 119 120 PJDLOG_ASSERT(sock >= 0); 121 122 #ifdef MSG_CMSG_CLOEXEC 123 flags = MSG_CMSG_CLOEXEC; 124 #else 125 flags = 0; 126 #endif 127 128 for (;;) { 129 fd_wait(sock, true); 130 if (recvmsg(sock, msg, flags) == -1) { 131 if (errno == EINTR) 132 continue; 133 return (-1); 134 } 135 break; 136 } 137 138 return (0); 139 } 140 141 static int 142 msg_send(int sock, const struct msghdr *msg) 143 { 144 145 PJDLOG_ASSERT(sock >= 0); 146 147 for (;;) { 148 fd_wait(sock, false); 149 if (sendmsg(sock, msg, 0) == -1) { 150 if (errno == EINTR) 151 continue; 152 return (-1); 153 } 154 break; 155 } 156 157 return (0); 158 } 159 160 /* 161 * MacOS/Linux do not define struct cmsgcred but we need to bootstrap libnv 162 * when building on non-FreeBSD systems. Since they are not used during 163 * bootstrap we can just omit these two functions there. 164 */ 165 #ifndef __FreeBSD__ 166 #warning "cred_send() not supported on non-FreeBSD systems" 167 #else 168 int 169 cred_send(int sock) 170 { 171 unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))]; 172 struct msghdr msg; 173 struct cmsghdr *cmsg; 174 struct iovec iov; 175 uint8_t dummy; 176 177 bzero(credbuf, sizeof(credbuf)); 178 bzero(&msg, sizeof(msg)); 179 bzero(&iov, sizeof(iov)); 180 181 /* 182 * XXX: We send one byte along with the control message, because 183 * setting msg_iov to NULL only works if this is the first 184 * packet send over the socket. Once we send some data we 185 * won't be able to send credentials anymore. This is most 186 * likely a kernel bug. 187 */ 188 dummy = 0; 189 iov.iov_base = &dummy; 190 iov.iov_len = sizeof(dummy); 191 192 msg.msg_iov = &iov; 193 msg.msg_iovlen = 1; 194 msg.msg_control = credbuf; 195 msg.msg_controllen = sizeof(credbuf); 196 197 cmsg = CMSG_FIRSTHDR(&msg); 198 cmsg->cmsg_len = CMSG_LEN(sizeof(struct cmsgcred)); 199 cmsg->cmsg_level = SOL_SOCKET; 200 cmsg->cmsg_type = SCM_CREDS; 201 202 if (msg_send(sock, &msg) == -1) 203 return (-1); 204 205 return (0); 206 } 207 208 int 209 cred_recv(int sock, struct cmsgcred *cred) 210 { 211 unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))]; 212 struct msghdr msg; 213 struct cmsghdr *cmsg; 214 struct iovec iov; 215 uint8_t dummy; 216 217 bzero(credbuf, sizeof(credbuf)); 218 bzero(&msg, sizeof(msg)); 219 bzero(&iov, sizeof(iov)); 220 221 iov.iov_base = &dummy; 222 iov.iov_len = sizeof(dummy); 223 224 msg.msg_iov = &iov; 225 msg.msg_iovlen = 1; 226 msg.msg_control = credbuf; 227 msg.msg_controllen = sizeof(credbuf); 228 229 if (msg_recv(sock, &msg) == -1) 230 return (-1); 231 232 cmsg = CMSG_FIRSTHDR(&msg); 233 if (cmsg == NULL || 234 cmsg->cmsg_len != CMSG_LEN(sizeof(struct cmsgcred)) || 235 cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDS) { 236 errno = EINVAL; 237 return (-1); 238 } 239 bcopy(CMSG_DATA(cmsg), cred, sizeof(*cred)); 240 241 return (0); 242 } 243 #endif 244 245 static int 246 fd_package_send(int sock, const int *fds, size_t nfds) 247 { 248 struct msghdr msg; 249 struct cmsghdr *cmsg; 250 struct iovec iov; 251 unsigned int i; 252 int serrno, ret; 253 uint8_t dummy; 254 255 PJDLOG_ASSERT(sock >= 0); 256 PJDLOG_ASSERT(fds != NULL); 257 PJDLOG_ASSERT(nfds > 0); 258 259 bzero(&msg, sizeof(msg)); 260 261 /* 262 * XXX: Look into cred_send function for more details. 263 */ 264 dummy = 0; 265 iov.iov_base = &dummy; 266 iov.iov_len = sizeof(dummy); 267 268 msg.msg_iov = &iov; 269 msg.msg_iovlen = 1; 270 msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int)); 271 msg.msg_control = calloc(1, msg.msg_controllen); 272 if (msg.msg_control == NULL) 273 return (-1); 274 275 ret = -1; 276 277 for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL; 278 i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) { 279 if (msghdr_add_fd(cmsg, fds[i]) == -1) 280 goto end; 281 } 282 283 if (msg_send(sock, &msg) == -1) 284 goto end; 285 286 ret = 0; 287 end: 288 serrno = errno; 289 free(msg.msg_control); 290 errno = serrno; 291 return (ret); 292 } 293 294 static int 295 fd_package_recv(int sock, int *fds, size_t nfds) 296 { 297 struct msghdr msg; 298 struct cmsghdr *cmsg; 299 unsigned int i; 300 int serrno, ret; 301 struct iovec iov; 302 uint8_t dummy; 303 304 PJDLOG_ASSERT(sock >= 0); 305 PJDLOG_ASSERT(nfds > 0); 306 PJDLOG_ASSERT(fds != NULL); 307 308 bzero(&msg, sizeof(msg)); 309 bzero(&iov, sizeof(iov)); 310 311 /* 312 * XXX: Look into cred_send function for more details. 313 */ 314 iov.iov_base = &dummy; 315 iov.iov_len = sizeof(dummy); 316 317 msg.msg_iov = &iov; 318 msg.msg_iovlen = 1; 319 msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int)); 320 msg.msg_control = calloc(1, msg.msg_controllen); 321 if (msg.msg_control == NULL) 322 return (-1); 323 324 ret = -1; 325 326 if (msg_recv(sock, &msg) == -1) 327 goto end; 328 329 for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL; 330 i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) { 331 fds[i] = msghdr_get_fd(cmsg); 332 if (fds[i] < 0) 333 break; 334 } 335 336 if (cmsg != NULL || i < nfds) { 337 int fd; 338 339 /* 340 * We need to close all received descriptors, even if we have 341 * different control message (eg. SCM_CREDS) in between. 342 */ 343 for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; 344 cmsg = CMSG_NXTHDR(&msg, cmsg)) { 345 fd = msghdr_get_fd(cmsg); 346 if (fd >= 0) 347 close(fd); 348 } 349 errno = EINVAL; 350 goto end; 351 } 352 353 ret = 0; 354 end: 355 serrno = errno; 356 free(msg.msg_control); 357 errno = serrno; 358 return (ret); 359 } 360 361 int 362 fd_recv(int sock, int *fds, size_t nfds) 363 { 364 unsigned int i, step, j; 365 int ret, serrno; 366 367 if (nfds == 0 || fds == NULL) { 368 errno = EINVAL; 369 return (-1); 370 } 371 372 ret = i = step = 0; 373 while (i < nfds) { 374 if (PKG_MAX_SIZE < nfds - i) 375 step = PKG_MAX_SIZE; 376 else 377 step = nfds - i; 378 ret = fd_package_recv(sock, fds + i, step); 379 if (ret != 0) { 380 /* Close all received descriptors. */ 381 serrno = errno; 382 for (j = 0; j < i; j++) 383 close(fds[j]); 384 errno = serrno; 385 break; 386 } 387 i += step; 388 } 389 390 return (ret); 391 } 392 393 int 394 fd_send(int sock, const int *fds, size_t nfds) 395 { 396 unsigned int i, step; 397 int ret; 398 399 if (nfds == 0 || fds == NULL) { 400 errno = EINVAL; 401 return (-1); 402 } 403 404 ret = i = step = 0; 405 while (i < nfds) { 406 if (PKG_MAX_SIZE < nfds - i) 407 step = PKG_MAX_SIZE; 408 else 409 step = nfds - i; 410 ret = fd_package_send(sock, fds + i, step); 411 if (ret != 0) 412 break; 413 i += step; 414 } 415 416 return (ret); 417 } 418 419 int 420 buf_send(int sock, void *buf, size_t size) 421 { 422 ssize_t done; 423 unsigned char *ptr; 424 425 PJDLOG_ASSERT(sock >= 0); 426 PJDLOG_ASSERT(size > 0); 427 PJDLOG_ASSERT(buf != NULL); 428 429 ptr = buf; 430 do { 431 fd_wait(sock, false); 432 done = send(sock, ptr, size, 0); 433 if (done == -1) { 434 if (errno == EINTR) 435 continue; 436 return (-1); 437 } else if (done == 0) { 438 errno = ENOTCONN; 439 return (-1); 440 } 441 size -= done; 442 ptr += done; 443 } while (size > 0); 444 445 return (0); 446 } 447 448 int 449 buf_recv(int sock, void *buf, size_t size) 450 { 451 ssize_t done; 452 unsigned char *ptr; 453 454 PJDLOG_ASSERT(sock >= 0); 455 PJDLOG_ASSERT(buf != NULL); 456 457 ptr = buf; 458 while (size > 0) { 459 fd_wait(sock, true); 460 done = recv(sock, ptr, size, 0); 461 if (done == -1) { 462 if (errno == EINTR) 463 continue; 464 return (-1); 465 } else if (done == 0) { 466 errno = ENOTCONN; 467 return (-1); 468 } 469 size -= done; 470 ptr += done; 471 } 472 473 return (0); 474 } 475