1 /*- 2 * SPDX-License-Identifier: BSD-2-Clause-FreeBSD 3 * 4 * Copyright (c) 2013 The FreeBSD Foundation 5 * Copyright (c) 2013 Mariusz Zaborski <oshogbo@FreeBSD.org> 6 * All rights reserved. 7 * 8 * This software was developed by Pawel Jakub Dawidek under sponsorship from 9 * the FreeBSD Foundation. 10 * 11 * Redistribution and use in source and binary forms, with or without 12 * modification, are permitted provided that the following conditions 13 * are met: 14 * 1. Redistributions of source code must retain the above copyright 15 * notice, this list of conditions and the following disclaimer. 16 * 2. Redistributions in binary form must reproduce the above copyright 17 * notice, this list of conditions and the following disclaimer in the 18 * documentation and/or other materials provided with the distribution. 19 * 20 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND 21 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 23 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE 24 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS 26 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 27 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 28 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 29 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF 30 * SUCH DAMAGE. 31 */ 32 33 #include <sys/cdefs.h> 34 __FBSDID("$FreeBSD$"); 35 36 #include <sys/param.h> 37 #include <sys/socket.h> 38 39 #include <errno.h> 40 #include <fcntl.h> 41 #include <stdbool.h> 42 #include <stdint.h> 43 #include <stdlib.h> 44 #include <string.h> 45 #include <unistd.h> 46 47 #ifdef HAVE_PJDLOG 48 #include <pjdlog.h> 49 #endif 50 51 #include "common_impl.h" 52 #include "msgio.h" 53 54 #ifndef HAVE_PJDLOG 55 #include <assert.h> 56 #define PJDLOG_ASSERT(...) assert(__VA_ARGS__) 57 #define PJDLOG_RASSERT(expr, ...) assert(expr) 58 #define PJDLOG_ABORT(...) abort() 59 #endif 60 61 #define PKG_MAX_SIZE (MCLBYTES / CMSG_SPACE(sizeof(int)) - 1) 62 63 static int 64 msghdr_add_fd(struct cmsghdr *cmsg, int fd) 65 { 66 67 PJDLOG_ASSERT(fd >= 0); 68 69 if (!fd_is_valid(fd)) { 70 errno = EBADF; 71 return (-1); 72 } 73 74 cmsg->cmsg_level = SOL_SOCKET; 75 cmsg->cmsg_type = SCM_RIGHTS; 76 cmsg->cmsg_len = CMSG_LEN(sizeof(fd)); 77 bcopy(&fd, CMSG_DATA(cmsg), sizeof(fd)); 78 79 return (0); 80 } 81 82 static int 83 msghdr_get_fd(struct cmsghdr *cmsg) 84 { 85 int fd; 86 87 if (cmsg == NULL || cmsg->cmsg_level != SOL_SOCKET || 88 cmsg->cmsg_type != SCM_RIGHTS || 89 cmsg->cmsg_len != CMSG_LEN(sizeof(fd))) { 90 errno = EINVAL; 91 return (-1); 92 } 93 94 bcopy(CMSG_DATA(cmsg), &fd, sizeof(fd)); 95 #ifndef MSG_CMSG_CLOEXEC 96 /* 97 * If the MSG_CMSG_CLOEXEC flag is not available we cannot set the 98 * close-on-exec flag atomically, but we still want to set it for 99 * consistency. 100 */ 101 (void) fcntl(fd, F_SETFD, FD_CLOEXEC); 102 #endif 103 104 return (fd); 105 } 106 107 static void 108 fd_wait(int fd, bool doread) 109 { 110 fd_set fds; 111 112 PJDLOG_ASSERT(fd >= 0); 113 114 FD_ZERO(&fds); 115 FD_SET(fd, &fds); 116 (void)select(fd + 1, doread ? &fds : NULL, doread ? NULL : &fds, 117 NULL, NULL); 118 } 119 120 static int 121 msg_recv(int sock, struct msghdr *msg) 122 { 123 int flags; 124 125 PJDLOG_ASSERT(sock >= 0); 126 127 #ifdef MSG_CMSG_CLOEXEC 128 flags = MSG_CMSG_CLOEXEC; 129 #else 130 flags = 0; 131 #endif 132 133 for (;;) { 134 fd_wait(sock, true); 135 if (recvmsg(sock, msg, flags) == -1) { 136 if (errno == EINTR) 137 continue; 138 return (-1); 139 } 140 break; 141 } 142 143 return (0); 144 } 145 146 static int 147 msg_send(int sock, const struct msghdr *msg) 148 { 149 150 PJDLOG_ASSERT(sock >= 0); 151 152 for (;;) { 153 fd_wait(sock, false); 154 if (sendmsg(sock, msg, 0) == -1) { 155 if (errno == EINTR) 156 continue; 157 return (-1); 158 } 159 break; 160 } 161 162 return (0); 163 } 164 165 int 166 cred_send(int sock) 167 { 168 unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))]; 169 struct msghdr msg; 170 struct cmsghdr *cmsg; 171 struct iovec iov; 172 uint8_t dummy; 173 174 bzero(credbuf, sizeof(credbuf)); 175 bzero(&msg, sizeof(msg)); 176 bzero(&iov, sizeof(iov)); 177 178 /* 179 * XXX: We send one byte along with the control message, because 180 * setting msg_iov to NULL only works if this is the first 181 * packet send over the socket. Once we send some data we 182 * won't be able to send credentials anymore. This is most 183 * likely a kernel bug. 184 */ 185 dummy = 0; 186 iov.iov_base = &dummy; 187 iov.iov_len = sizeof(dummy); 188 189 msg.msg_iov = &iov; 190 msg.msg_iovlen = 1; 191 msg.msg_control = credbuf; 192 msg.msg_controllen = sizeof(credbuf); 193 194 cmsg = CMSG_FIRSTHDR(&msg); 195 cmsg->cmsg_len = CMSG_LEN(sizeof(struct cmsgcred)); 196 cmsg->cmsg_level = SOL_SOCKET; 197 cmsg->cmsg_type = SCM_CREDS; 198 199 if (msg_send(sock, &msg) == -1) 200 return (-1); 201 202 return (0); 203 } 204 205 int 206 cred_recv(int sock, struct cmsgcred *cred) 207 { 208 unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))]; 209 struct msghdr msg; 210 struct cmsghdr *cmsg; 211 struct iovec iov; 212 uint8_t dummy; 213 214 bzero(credbuf, sizeof(credbuf)); 215 bzero(&msg, sizeof(msg)); 216 bzero(&iov, sizeof(iov)); 217 218 iov.iov_base = &dummy; 219 iov.iov_len = sizeof(dummy); 220 221 msg.msg_iov = &iov; 222 msg.msg_iovlen = 1; 223 msg.msg_control = credbuf; 224 msg.msg_controllen = sizeof(credbuf); 225 226 if (msg_recv(sock, &msg) == -1) 227 return (-1); 228 229 cmsg = CMSG_FIRSTHDR(&msg); 230 if (cmsg == NULL || 231 cmsg->cmsg_len != CMSG_LEN(sizeof(struct cmsgcred)) || 232 cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDS) { 233 errno = EINVAL; 234 return (-1); 235 } 236 bcopy(CMSG_DATA(cmsg), cred, sizeof(*cred)); 237 238 return (0); 239 } 240 241 static int 242 fd_package_send(int sock, const int *fds, size_t nfds) 243 { 244 struct msghdr msg; 245 struct cmsghdr *cmsg; 246 struct iovec iov; 247 unsigned int i; 248 int serrno, ret; 249 uint8_t dummy; 250 251 PJDLOG_ASSERT(sock >= 0); 252 PJDLOG_ASSERT(fds != NULL); 253 PJDLOG_ASSERT(nfds > 0); 254 255 bzero(&msg, sizeof(msg)); 256 257 /* 258 * XXX: Look into cred_send function for more details. 259 */ 260 dummy = 0; 261 iov.iov_base = &dummy; 262 iov.iov_len = sizeof(dummy); 263 264 msg.msg_iov = &iov; 265 msg.msg_iovlen = 1; 266 msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int)); 267 msg.msg_control = calloc(1, msg.msg_controllen); 268 if (msg.msg_control == NULL) 269 return (-1); 270 271 ret = -1; 272 273 for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL; 274 i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) { 275 if (msghdr_add_fd(cmsg, fds[i]) == -1) 276 goto end; 277 } 278 279 if (msg_send(sock, &msg) == -1) 280 goto end; 281 282 ret = 0; 283 end: 284 serrno = errno; 285 free(msg.msg_control); 286 errno = serrno; 287 return (ret); 288 } 289 290 static int 291 fd_package_recv(int sock, int *fds, size_t nfds) 292 { 293 struct msghdr msg; 294 struct cmsghdr *cmsg; 295 unsigned int i; 296 int serrno, ret; 297 struct iovec iov; 298 uint8_t dummy; 299 300 PJDLOG_ASSERT(sock >= 0); 301 PJDLOG_ASSERT(nfds > 0); 302 PJDLOG_ASSERT(fds != NULL); 303 304 bzero(&msg, sizeof(msg)); 305 bzero(&iov, sizeof(iov)); 306 307 /* 308 * XXX: Look into cred_send function for more details. 309 */ 310 iov.iov_base = &dummy; 311 iov.iov_len = sizeof(dummy); 312 313 msg.msg_iov = &iov; 314 msg.msg_iovlen = 1; 315 msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int)); 316 msg.msg_control = calloc(1, msg.msg_controllen); 317 if (msg.msg_control == NULL) 318 return (-1); 319 320 ret = -1; 321 322 if (msg_recv(sock, &msg) == -1) 323 goto end; 324 325 for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL; 326 i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) { 327 fds[i] = msghdr_get_fd(cmsg); 328 if (fds[i] < 0) 329 break; 330 } 331 332 if (cmsg != NULL || i < nfds) { 333 int fd; 334 335 /* 336 * We need to close all received descriptors, even if we have 337 * different control message (eg. SCM_CREDS) in between. 338 */ 339 for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; 340 cmsg = CMSG_NXTHDR(&msg, cmsg)) { 341 fd = msghdr_get_fd(cmsg); 342 if (fd >= 0) 343 close(fd); 344 } 345 errno = EINVAL; 346 goto end; 347 } 348 349 ret = 0; 350 end: 351 serrno = errno; 352 free(msg.msg_control); 353 errno = serrno; 354 return (ret); 355 } 356 357 int 358 fd_recv(int sock, int *fds, size_t nfds) 359 { 360 unsigned int i, step, j; 361 int ret, serrno; 362 363 if (nfds == 0 || fds == NULL) { 364 errno = EINVAL; 365 return (-1); 366 } 367 368 ret = i = step = 0; 369 while (i < nfds) { 370 if (PKG_MAX_SIZE < nfds - i) 371 step = PKG_MAX_SIZE; 372 else 373 step = nfds - i; 374 ret = fd_package_recv(sock, fds + i, step); 375 if (ret != 0) { 376 /* Close all received descriptors. */ 377 serrno = errno; 378 for (j = 0; j < i; j++) 379 close(fds[j]); 380 errno = serrno; 381 break; 382 } 383 i += step; 384 } 385 386 return (ret); 387 } 388 389 int 390 fd_send(int sock, const int *fds, size_t nfds) 391 { 392 unsigned int i, step; 393 int ret; 394 395 if (nfds == 0 || fds == NULL) { 396 errno = EINVAL; 397 return (-1); 398 } 399 400 ret = i = step = 0; 401 while (i < nfds) { 402 if (PKG_MAX_SIZE < nfds - i) 403 step = PKG_MAX_SIZE; 404 else 405 step = nfds - i; 406 ret = fd_package_send(sock, fds + i, step); 407 if (ret != 0) 408 break; 409 i += step; 410 } 411 412 return (ret); 413 } 414 415 int 416 buf_send(int sock, void *buf, size_t size) 417 { 418 ssize_t done; 419 unsigned char *ptr; 420 421 PJDLOG_ASSERT(sock >= 0); 422 PJDLOG_ASSERT(size > 0); 423 PJDLOG_ASSERT(buf != NULL); 424 425 ptr = buf; 426 do { 427 fd_wait(sock, false); 428 done = send(sock, ptr, size, 0); 429 if (done == -1) { 430 if (errno == EINTR) 431 continue; 432 return (-1); 433 } else if (done == 0) { 434 errno = ENOTCONN; 435 return (-1); 436 } 437 size -= done; 438 ptr += done; 439 } while (size > 0); 440 441 return (0); 442 } 443 444 int 445 buf_recv(int sock, void *buf, size_t size) 446 { 447 ssize_t done; 448 unsigned char *ptr; 449 450 PJDLOG_ASSERT(sock >= 0); 451 PJDLOG_ASSERT(buf != NULL); 452 453 ptr = buf; 454 while (size > 0) { 455 fd_wait(sock, true); 456 done = recv(sock, ptr, size, 0); 457 if (done == -1) { 458 if (errno == EINTR) 459 continue; 460 return (-1); 461 } else if (done == 0) { 462 errno = ENOTCONN; 463 return (-1); 464 } 465 size -= done; 466 ptr += done; 467 } 468 469 return (0); 470 } 471