1 /*- 2 * SPDX-License-Identifier: BSD-2-Clause-FreeBSD 3 * 4 * Copyright (c) 2013 The FreeBSD Foundation 5 * Copyright (c) 2013 Mariusz Zaborski <oshogbo@FreeBSD.org> 6 * All rights reserved. 7 * 8 * This software was developed by Pawel Jakub Dawidek under sponsorship from 9 * the FreeBSD Foundation. 10 * 11 * Redistribution and use in source and binary forms, with or without 12 * modification, are permitted provided that the following conditions 13 * are met: 14 * 1. Redistributions of source code must retain the above copyright 15 * notice, this list of conditions and the following disclaimer. 16 * 2. Redistributions in binary form must reproduce the above copyright 17 * notice, this list of conditions and the following disclaimer in the 18 * documentation and/or other materials provided with the distribution. 19 * 20 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND 21 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 23 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE 24 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS 26 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 27 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 28 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 29 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF 30 * SUCH DAMAGE. 31 */ 32 33 #include <sys/cdefs.h> 34 __FBSDID("$FreeBSD$"); 35 36 #include <sys/param.h> 37 #include <sys/socket.h> 38 39 #include <errno.h> 40 #include <fcntl.h> 41 #include <stdbool.h> 42 #include <stdint.h> 43 #include <stdlib.h> 44 #include <string.h> 45 #include <unistd.h> 46 47 #ifdef HAVE_PJDLOG 48 #include <pjdlog.h> 49 #endif 50 51 #include "common_impl.h" 52 #include "msgio.h" 53 54 #ifndef HAVE_PJDLOG 55 #include <assert.h> 56 #define PJDLOG_ASSERT(...) assert(__VA_ARGS__) 57 #define PJDLOG_RASSERT(expr, ...) assert(expr) 58 #define PJDLOG_ABORT(...) abort() 59 #endif 60 61 #define PKG_MAX_SIZE (MCLBYTES / CMSG_SPACE(sizeof(int)) - 1) 62 63 static int 64 msghdr_add_fd(struct cmsghdr *cmsg, int fd) 65 { 66 67 PJDLOG_ASSERT(fd >= 0); 68 69 cmsg->cmsg_level = SOL_SOCKET; 70 cmsg->cmsg_type = SCM_RIGHTS; 71 cmsg->cmsg_len = CMSG_LEN(sizeof(fd)); 72 bcopy(&fd, CMSG_DATA(cmsg), sizeof(fd)); 73 74 return (0); 75 } 76 77 static int 78 msghdr_get_fd(struct cmsghdr *cmsg) 79 { 80 int fd; 81 82 if (cmsg == NULL || cmsg->cmsg_level != SOL_SOCKET || 83 cmsg->cmsg_type != SCM_RIGHTS || 84 cmsg->cmsg_len != CMSG_LEN(sizeof(fd))) { 85 errno = EINVAL; 86 return (-1); 87 } 88 89 bcopy(CMSG_DATA(cmsg), &fd, sizeof(fd)); 90 #ifndef MSG_CMSG_CLOEXEC 91 /* 92 * If the MSG_CMSG_CLOEXEC flag is not available we cannot set the 93 * close-on-exec flag atomically, but we still want to set it for 94 * consistency. 95 */ 96 (void) fcntl(fd, F_SETFD, FD_CLOEXEC); 97 #endif 98 99 return (fd); 100 } 101 102 static void 103 fd_wait(int fd, bool doread) 104 { 105 fd_set fds; 106 107 PJDLOG_ASSERT(fd >= 0); 108 109 FD_ZERO(&fds); 110 FD_SET(fd, &fds); 111 (void)select(fd + 1, doread ? &fds : NULL, doread ? NULL : &fds, 112 NULL, NULL); 113 } 114 115 static int 116 msg_recv(int sock, struct msghdr *msg) 117 { 118 int flags; 119 120 PJDLOG_ASSERT(sock >= 0); 121 122 #ifdef MSG_CMSG_CLOEXEC 123 flags = MSG_CMSG_CLOEXEC; 124 #else 125 flags = 0; 126 #endif 127 128 for (;;) { 129 fd_wait(sock, true); 130 if (recvmsg(sock, msg, flags) == -1) { 131 if (errno == EINTR) 132 continue; 133 return (-1); 134 } 135 break; 136 } 137 138 return (0); 139 } 140 141 static int 142 msg_send(int sock, const struct msghdr *msg) 143 { 144 145 PJDLOG_ASSERT(sock >= 0); 146 147 for (;;) { 148 fd_wait(sock, false); 149 if (sendmsg(sock, msg, 0) == -1) { 150 if (errno == EINTR) 151 continue; 152 return (-1); 153 } 154 break; 155 } 156 157 return (0); 158 } 159 160 int 161 cred_send(int sock) 162 { 163 unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))]; 164 struct msghdr msg; 165 struct cmsghdr *cmsg; 166 struct iovec iov; 167 uint8_t dummy; 168 169 bzero(credbuf, sizeof(credbuf)); 170 bzero(&msg, sizeof(msg)); 171 bzero(&iov, sizeof(iov)); 172 173 /* 174 * XXX: We send one byte along with the control message, because 175 * setting msg_iov to NULL only works if this is the first 176 * packet send over the socket. Once we send some data we 177 * won't be able to send credentials anymore. This is most 178 * likely a kernel bug. 179 */ 180 dummy = 0; 181 iov.iov_base = &dummy; 182 iov.iov_len = sizeof(dummy); 183 184 msg.msg_iov = &iov; 185 msg.msg_iovlen = 1; 186 msg.msg_control = credbuf; 187 msg.msg_controllen = sizeof(credbuf); 188 189 cmsg = CMSG_FIRSTHDR(&msg); 190 cmsg->cmsg_len = CMSG_LEN(sizeof(struct cmsgcred)); 191 cmsg->cmsg_level = SOL_SOCKET; 192 cmsg->cmsg_type = SCM_CREDS; 193 194 if (msg_send(sock, &msg) == -1) 195 return (-1); 196 197 return (0); 198 } 199 200 int 201 cred_recv(int sock, struct cmsgcred *cred) 202 { 203 unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))]; 204 struct msghdr msg; 205 struct cmsghdr *cmsg; 206 struct iovec iov; 207 uint8_t dummy; 208 209 bzero(credbuf, sizeof(credbuf)); 210 bzero(&msg, sizeof(msg)); 211 bzero(&iov, sizeof(iov)); 212 213 iov.iov_base = &dummy; 214 iov.iov_len = sizeof(dummy); 215 216 msg.msg_iov = &iov; 217 msg.msg_iovlen = 1; 218 msg.msg_control = credbuf; 219 msg.msg_controllen = sizeof(credbuf); 220 221 if (msg_recv(sock, &msg) == -1) 222 return (-1); 223 224 cmsg = CMSG_FIRSTHDR(&msg); 225 if (cmsg == NULL || 226 cmsg->cmsg_len != CMSG_LEN(sizeof(struct cmsgcred)) || 227 cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDS) { 228 errno = EINVAL; 229 return (-1); 230 } 231 bcopy(CMSG_DATA(cmsg), cred, sizeof(*cred)); 232 233 return (0); 234 } 235 236 static int 237 fd_package_send(int sock, const int *fds, size_t nfds) 238 { 239 struct msghdr msg; 240 struct cmsghdr *cmsg; 241 struct iovec iov; 242 unsigned int i; 243 int serrno, ret; 244 uint8_t dummy; 245 246 PJDLOG_ASSERT(sock >= 0); 247 PJDLOG_ASSERT(fds != NULL); 248 PJDLOG_ASSERT(nfds > 0); 249 250 bzero(&msg, sizeof(msg)); 251 252 /* 253 * XXX: Look into cred_send function for more details. 254 */ 255 dummy = 0; 256 iov.iov_base = &dummy; 257 iov.iov_len = sizeof(dummy); 258 259 msg.msg_iov = &iov; 260 msg.msg_iovlen = 1; 261 msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int)); 262 msg.msg_control = calloc(1, msg.msg_controllen); 263 if (msg.msg_control == NULL) 264 return (-1); 265 266 ret = -1; 267 268 for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL; 269 i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) { 270 if (msghdr_add_fd(cmsg, fds[i]) == -1) 271 goto end; 272 } 273 274 if (msg_send(sock, &msg) == -1) 275 goto end; 276 277 ret = 0; 278 end: 279 serrno = errno; 280 free(msg.msg_control); 281 errno = serrno; 282 return (ret); 283 } 284 285 static int 286 fd_package_recv(int sock, int *fds, size_t nfds) 287 { 288 struct msghdr msg; 289 struct cmsghdr *cmsg; 290 unsigned int i; 291 int serrno, ret; 292 struct iovec iov; 293 uint8_t dummy; 294 295 PJDLOG_ASSERT(sock >= 0); 296 PJDLOG_ASSERT(nfds > 0); 297 PJDLOG_ASSERT(fds != NULL); 298 299 bzero(&msg, sizeof(msg)); 300 bzero(&iov, sizeof(iov)); 301 302 /* 303 * XXX: Look into cred_send function for more details. 304 */ 305 iov.iov_base = &dummy; 306 iov.iov_len = sizeof(dummy); 307 308 msg.msg_iov = &iov; 309 msg.msg_iovlen = 1; 310 msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int)); 311 msg.msg_control = calloc(1, msg.msg_controllen); 312 if (msg.msg_control == NULL) 313 return (-1); 314 315 ret = -1; 316 317 if (msg_recv(sock, &msg) == -1) 318 goto end; 319 320 for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL; 321 i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) { 322 fds[i] = msghdr_get_fd(cmsg); 323 if (fds[i] < 0) 324 break; 325 } 326 327 if (cmsg != NULL || i < nfds) { 328 int fd; 329 330 /* 331 * We need to close all received descriptors, even if we have 332 * different control message (eg. SCM_CREDS) in between. 333 */ 334 for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; 335 cmsg = CMSG_NXTHDR(&msg, cmsg)) { 336 fd = msghdr_get_fd(cmsg); 337 if (fd >= 0) 338 close(fd); 339 } 340 errno = EINVAL; 341 goto end; 342 } 343 344 ret = 0; 345 end: 346 serrno = errno; 347 free(msg.msg_control); 348 errno = serrno; 349 return (ret); 350 } 351 352 int 353 fd_recv(int sock, int *fds, size_t nfds) 354 { 355 unsigned int i, step, j; 356 int ret, serrno; 357 358 if (nfds == 0 || fds == NULL) { 359 errno = EINVAL; 360 return (-1); 361 } 362 363 ret = i = step = 0; 364 while (i < nfds) { 365 if (PKG_MAX_SIZE < nfds - i) 366 step = PKG_MAX_SIZE; 367 else 368 step = nfds - i; 369 ret = fd_package_recv(sock, fds + i, step); 370 if (ret != 0) { 371 /* Close all received descriptors. */ 372 serrno = errno; 373 for (j = 0; j < i; j++) 374 close(fds[j]); 375 errno = serrno; 376 break; 377 } 378 i += step; 379 } 380 381 return (ret); 382 } 383 384 int 385 fd_send(int sock, const int *fds, size_t nfds) 386 { 387 unsigned int i, step; 388 int ret; 389 390 if (nfds == 0 || fds == NULL) { 391 errno = EINVAL; 392 return (-1); 393 } 394 395 ret = i = step = 0; 396 while (i < nfds) { 397 if (PKG_MAX_SIZE < nfds - i) 398 step = PKG_MAX_SIZE; 399 else 400 step = nfds - i; 401 ret = fd_package_send(sock, fds + i, step); 402 if (ret != 0) 403 break; 404 i += step; 405 } 406 407 return (ret); 408 } 409 410 int 411 buf_send(int sock, void *buf, size_t size) 412 { 413 ssize_t done; 414 unsigned char *ptr; 415 416 PJDLOG_ASSERT(sock >= 0); 417 PJDLOG_ASSERT(size > 0); 418 PJDLOG_ASSERT(buf != NULL); 419 420 ptr = buf; 421 do { 422 fd_wait(sock, false); 423 done = send(sock, ptr, size, 0); 424 if (done == -1) { 425 if (errno == EINTR) 426 continue; 427 return (-1); 428 } else if (done == 0) { 429 errno = ENOTCONN; 430 return (-1); 431 } 432 size -= done; 433 ptr += done; 434 } while (size > 0); 435 436 return (0); 437 } 438 439 int 440 buf_recv(int sock, void *buf, size_t size) 441 { 442 ssize_t done; 443 unsigned char *ptr; 444 445 PJDLOG_ASSERT(sock >= 0); 446 PJDLOG_ASSERT(buf != NULL); 447 448 ptr = buf; 449 while (size > 0) { 450 fd_wait(sock, true); 451 done = recv(sock, ptr, size, 0); 452 if (done == -1) { 453 if (errno == EINTR) 454 continue; 455 return (-1); 456 } else if (done == 0) { 457 errno = ENOTCONN; 458 return (-1); 459 } 460 size -= done; 461 ptr += done; 462 } 463 464 return (0); 465 } 466