1 /*- 2 * Copyright (c) 2013 The FreeBSD Foundation 3 * Copyright (c) 2013 Mariusz Zaborski <oshogbo@FreeBSD.org> 4 * All rights reserved. 5 * 6 * This software was developed by Pawel Jakub Dawidek under sponsorship from 7 * the FreeBSD Foundation. 8 * 9 * Redistribution and use in source and binary forms, with or without 10 * modification, are permitted provided that the following conditions 11 * are met: 12 * 1. Redistributions of source code must retain the above copyright 13 * notice, this list of conditions and the following disclaimer. 14 * 2. Redistributions in binary form must reproduce the above copyright 15 * notice, this list of conditions and the following disclaimer in the 16 * documentation and/or other materials provided with the distribution. 17 * 18 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND 19 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE 22 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS 24 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 25 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 26 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 27 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF 28 * SUCH DAMAGE. 29 */ 30 31 #include <sys/cdefs.h> 32 __FBSDID("$FreeBSD$"); 33 34 #include <sys/param.h> 35 #include <sys/socket.h> 36 37 #include <errno.h> 38 #include <fcntl.h> 39 #include <stdbool.h> 40 #include <stdint.h> 41 #include <stdlib.h> 42 #include <string.h> 43 #include <unistd.h> 44 45 #ifdef HAVE_PJDLOG 46 #include <pjdlog.h> 47 #endif 48 49 #include "common_impl.h" 50 #include "msgio.h" 51 52 #ifndef HAVE_PJDLOG 53 #include <assert.h> 54 #define PJDLOG_ASSERT(...) assert(__VA_ARGS__) 55 #define PJDLOG_RASSERT(expr, ...) assert(expr) 56 #define PJDLOG_ABORT(...) abort() 57 #endif 58 59 #define PKG_MAX_SIZE (MCLBYTES / CMSG_SPACE(sizeof(int)) - 1) 60 61 static int 62 msghdr_add_fd(struct cmsghdr *cmsg, int fd) 63 { 64 65 PJDLOG_ASSERT(fd >= 0); 66 67 if (!fd_is_valid(fd)) { 68 errno = EBADF; 69 return (-1); 70 } 71 72 cmsg->cmsg_level = SOL_SOCKET; 73 cmsg->cmsg_type = SCM_RIGHTS; 74 cmsg->cmsg_len = CMSG_LEN(sizeof(fd)); 75 bcopy(&fd, CMSG_DATA(cmsg), sizeof(fd)); 76 77 return (0); 78 } 79 80 static int 81 msghdr_get_fd(struct cmsghdr *cmsg) 82 { 83 int fd; 84 85 if (cmsg == NULL || cmsg->cmsg_level != SOL_SOCKET || 86 cmsg->cmsg_type != SCM_RIGHTS || 87 cmsg->cmsg_len != CMSG_LEN(sizeof(fd))) { 88 errno = EINVAL; 89 return (-1); 90 } 91 92 bcopy(CMSG_DATA(cmsg), &fd, sizeof(fd)); 93 #ifndef MSG_CMSG_CLOEXEC 94 /* 95 * If the MSG_CMSG_CLOEXEC flag is not available we cannot set the 96 * close-on-exec flag atomically, but we still want to set it for 97 * consistency. 98 */ 99 (void) fcntl(fd, F_SETFD, FD_CLOEXEC); 100 #endif 101 102 return (fd); 103 } 104 105 static void 106 fd_wait(int fd, bool doread) 107 { 108 fd_set fds; 109 110 PJDLOG_ASSERT(fd >= 0); 111 112 FD_ZERO(&fds); 113 FD_SET(fd, &fds); 114 (void)select(fd + 1, doread ? &fds : NULL, doread ? NULL : &fds, 115 NULL, NULL); 116 } 117 118 static int 119 msg_recv(int sock, struct msghdr *msg) 120 { 121 int flags; 122 123 PJDLOG_ASSERT(sock >= 0); 124 125 #ifdef MSG_CMSG_CLOEXEC 126 flags = MSG_CMSG_CLOEXEC; 127 #else 128 flags = 0; 129 #endif 130 131 for (;;) { 132 fd_wait(sock, true); 133 if (recvmsg(sock, msg, flags) == -1) { 134 if (errno == EINTR) 135 continue; 136 return (-1); 137 } 138 break; 139 } 140 141 return (0); 142 } 143 144 static int 145 msg_send(int sock, const struct msghdr *msg) 146 { 147 148 PJDLOG_ASSERT(sock >= 0); 149 150 for (;;) { 151 fd_wait(sock, false); 152 if (sendmsg(sock, msg, 0) == -1) { 153 if (errno == EINTR) 154 continue; 155 return (-1); 156 } 157 break; 158 } 159 160 return (0); 161 } 162 163 int 164 cred_send(int sock) 165 { 166 unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))]; 167 struct msghdr msg; 168 struct cmsghdr *cmsg; 169 struct iovec iov; 170 uint8_t dummy; 171 172 bzero(credbuf, sizeof(credbuf)); 173 bzero(&msg, sizeof(msg)); 174 bzero(&iov, sizeof(iov)); 175 176 /* 177 * XXX: We send one byte along with the control message, because 178 * setting msg_iov to NULL only works if this is the first 179 * packet send over the socket. Once we send some data we 180 * won't be able to send credentials anymore. This is most 181 * likely a kernel bug. 182 */ 183 dummy = 0; 184 iov.iov_base = &dummy; 185 iov.iov_len = sizeof(dummy); 186 187 msg.msg_iov = &iov; 188 msg.msg_iovlen = 1; 189 msg.msg_control = credbuf; 190 msg.msg_controllen = sizeof(credbuf); 191 192 cmsg = CMSG_FIRSTHDR(&msg); 193 cmsg->cmsg_len = CMSG_LEN(sizeof(struct cmsgcred)); 194 cmsg->cmsg_level = SOL_SOCKET; 195 cmsg->cmsg_type = SCM_CREDS; 196 197 if (msg_send(sock, &msg) == -1) 198 return (-1); 199 200 return (0); 201 } 202 203 int 204 cred_recv(int sock, struct cmsgcred *cred) 205 { 206 unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))]; 207 struct msghdr msg; 208 struct cmsghdr *cmsg; 209 struct iovec iov; 210 uint8_t dummy; 211 212 bzero(credbuf, sizeof(credbuf)); 213 bzero(&msg, sizeof(msg)); 214 bzero(&iov, sizeof(iov)); 215 216 iov.iov_base = &dummy; 217 iov.iov_len = sizeof(dummy); 218 219 msg.msg_iov = &iov; 220 msg.msg_iovlen = 1; 221 msg.msg_control = credbuf; 222 msg.msg_controllen = sizeof(credbuf); 223 224 if (msg_recv(sock, &msg) == -1) 225 return (-1); 226 227 cmsg = CMSG_FIRSTHDR(&msg); 228 if (cmsg == NULL || 229 cmsg->cmsg_len != CMSG_LEN(sizeof(struct cmsgcred)) || 230 cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDS) { 231 errno = EINVAL; 232 return (-1); 233 } 234 bcopy(CMSG_DATA(cmsg), cred, sizeof(*cred)); 235 236 return (0); 237 } 238 239 static int 240 fd_package_send(int sock, const int *fds, size_t nfds) 241 { 242 struct msghdr msg; 243 struct cmsghdr *cmsg; 244 struct iovec iov; 245 unsigned int i; 246 int serrno, ret; 247 uint8_t dummy; 248 249 PJDLOG_ASSERT(sock >= 0); 250 PJDLOG_ASSERT(fds != NULL); 251 PJDLOG_ASSERT(nfds > 0); 252 253 bzero(&msg, sizeof(msg)); 254 255 /* 256 * XXX: Look into cred_send function for more details. 257 */ 258 dummy = 0; 259 iov.iov_base = &dummy; 260 iov.iov_len = sizeof(dummy); 261 262 msg.msg_iov = &iov; 263 msg.msg_iovlen = 1; 264 msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int)); 265 msg.msg_control = calloc(1, msg.msg_controllen); 266 if (msg.msg_control == NULL) 267 return (-1); 268 269 ret = -1; 270 271 for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL; 272 i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) { 273 if (msghdr_add_fd(cmsg, fds[i]) == -1) 274 goto end; 275 } 276 277 if (msg_send(sock, &msg) == -1) 278 goto end; 279 280 ret = 0; 281 end: 282 serrno = errno; 283 free(msg.msg_control); 284 errno = serrno; 285 return (ret); 286 } 287 288 static int 289 fd_package_recv(int sock, int *fds, size_t nfds) 290 { 291 struct msghdr msg; 292 struct cmsghdr *cmsg; 293 unsigned int i; 294 int serrno, ret; 295 struct iovec iov; 296 uint8_t dummy; 297 298 PJDLOG_ASSERT(sock >= 0); 299 PJDLOG_ASSERT(nfds > 0); 300 PJDLOG_ASSERT(fds != NULL); 301 302 bzero(&msg, sizeof(msg)); 303 bzero(&iov, sizeof(iov)); 304 305 /* 306 * XXX: Look into cred_send function for more details. 307 */ 308 iov.iov_base = &dummy; 309 iov.iov_len = sizeof(dummy); 310 311 msg.msg_iov = &iov; 312 msg.msg_iovlen = 1; 313 msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int)); 314 msg.msg_control = calloc(1, msg.msg_controllen); 315 if (msg.msg_control == NULL) 316 return (-1); 317 318 ret = -1; 319 320 if (msg_recv(sock, &msg) == -1) 321 goto end; 322 323 for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL; 324 i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) { 325 fds[i] = msghdr_get_fd(cmsg); 326 if (fds[i] < 0) 327 break; 328 } 329 330 if (cmsg != NULL || i < nfds) { 331 int fd; 332 333 /* 334 * We need to close all received descriptors, even if we have 335 * different control message (eg. SCM_CREDS) in between. 336 */ 337 for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; 338 cmsg = CMSG_NXTHDR(&msg, cmsg)) { 339 fd = msghdr_get_fd(cmsg); 340 if (fd >= 0) 341 close(fd); 342 } 343 errno = EINVAL; 344 goto end; 345 } 346 347 ret = 0; 348 end: 349 serrno = errno; 350 free(msg.msg_control); 351 errno = serrno; 352 return (ret); 353 } 354 355 int 356 fd_recv(int sock, int *fds, size_t nfds) 357 { 358 unsigned int i, step, j; 359 int ret, serrno; 360 361 if (nfds == 0 || fds == NULL) { 362 errno = EINVAL; 363 return (-1); 364 } 365 366 ret = i = step = 0; 367 while (i < nfds) { 368 if (PKG_MAX_SIZE < nfds - i) 369 step = PKG_MAX_SIZE; 370 else 371 step = nfds - i; 372 ret = fd_package_recv(sock, fds + i, step); 373 if (ret != 0) { 374 /* Close all received descriptors. */ 375 serrno = errno; 376 for (j = 0; j < i; j++) 377 close(fds[j]); 378 errno = serrno; 379 break; 380 } 381 i += step; 382 } 383 384 return (ret); 385 } 386 387 int 388 fd_send(int sock, const int *fds, size_t nfds) 389 { 390 unsigned int i, step; 391 int ret; 392 393 if (nfds == 0 || fds == NULL) { 394 errno = EINVAL; 395 return (-1); 396 } 397 398 ret = i = step = 0; 399 while (i < nfds) { 400 if (PKG_MAX_SIZE < nfds - i) 401 step = PKG_MAX_SIZE; 402 else 403 step = nfds - i; 404 ret = fd_package_send(sock, fds + i, step); 405 if (ret != 0) 406 break; 407 i += step; 408 } 409 410 return (ret); 411 } 412 413 int 414 buf_send(int sock, void *buf, size_t size) 415 { 416 ssize_t done; 417 unsigned char *ptr; 418 419 PJDLOG_ASSERT(sock >= 0); 420 PJDLOG_ASSERT(size > 0); 421 PJDLOG_ASSERT(buf != NULL); 422 423 ptr = buf; 424 do { 425 fd_wait(sock, false); 426 done = send(sock, ptr, size, 0); 427 if (done == -1) { 428 if (errno == EINTR) 429 continue; 430 return (-1); 431 } else if (done == 0) { 432 errno = ENOTCONN; 433 return (-1); 434 } 435 size -= done; 436 ptr += done; 437 } while (size > 0); 438 439 return (0); 440 } 441 442 int 443 buf_recv(int sock, void *buf, size_t size) 444 { 445 ssize_t done; 446 unsigned char *ptr; 447 448 PJDLOG_ASSERT(sock >= 0); 449 PJDLOG_ASSERT(buf != NULL); 450 451 ptr = buf; 452 while (size > 0) { 453 fd_wait(sock, true); 454 done = recv(sock, ptr, size, 0); 455 if (done == -1) { 456 if (errno == EINTR) 457 continue; 458 return (-1); 459 } else if (done == 0) { 460 errno = ENOTCONN; 461 return (-1); 462 } 463 size -= done; 464 ptr += done; 465 } 466 467 return (0); 468 } 469