1 /*- 2 * Copyright (c) 2013 The FreeBSD Foundation 3 * Copyright (c) 2013 Mariusz Zaborski <oshogbo@FreeBSD.org> 4 * All rights reserved. 5 * 6 * This software was developed by Pawel Jakub Dawidek under sponsorship from 7 * the FreeBSD Foundation. 8 * 9 * Redistribution and use in source and binary forms, with or without 10 * modification, are permitted provided that the following conditions 11 * are met: 12 * 1. Redistributions of source code must retain the above copyright 13 * notice, this list of conditions and the following disclaimer. 14 * 2. Redistributions in binary form must reproduce the above copyright 15 * notice, this list of conditions and the following disclaimer in the 16 * documentation and/or other materials provided with the distribution. 17 * 18 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND 19 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE 22 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS 24 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 25 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 26 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 27 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF 28 * SUCH DAMAGE. 29 */ 30 31 #include <sys/cdefs.h> 32 __FBSDID("$FreeBSD$"); 33 34 #include <sys/param.h> 35 #include <sys/socket.h> 36 37 #include <errno.h> 38 #include <fcntl.h> 39 #include <stdbool.h> 40 #include <stdint.h> 41 #include <stdlib.h> 42 #include <string.h> 43 #include <unistd.h> 44 45 #ifdef HAVE_PJDLOG 46 #include <pjdlog.h> 47 #endif 48 49 #include "common_impl.h" 50 #include "msgio.h" 51 52 #ifndef HAVE_PJDLOG 53 #include <assert.h> 54 #define PJDLOG_ASSERT(...) assert(__VA_ARGS__) 55 #define PJDLOG_RASSERT(expr, ...) assert(expr) 56 #define PJDLOG_ABORT(...) abort() 57 #endif 58 59 #define PKG_MAX_SIZE (MCLBYTES / CMSG_SPACE(sizeof(int)) - 1) 60 61 static int 62 msghdr_add_fd(struct cmsghdr *cmsg, int fd) 63 { 64 65 PJDLOG_ASSERT(fd >= 0); 66 67 if (!fd_is_valid(fd)) { 68 errno = EBADF; 69 return (-1); 70 } 71 72 cmsg->cmsg_level = SOL_SOCKET; 73 cmsg->cmsg_type = SCM_RIGHTS; 74 cmsg->cmsg_len = CMSG_LEN(sizeof(fd)); 75 bcopy(&fd, CMSG_DATA(cmsg), sizeof(fd)); 76 77 return (0); 78 } 79 80 static int 81 msghdr_get_fd(struct cmsghdr *cmsg) 82 { 83 int fd; 84 85 if (cmsg == NULL || cmsg->cmsg_level != SOL_SOCKET || 86 cmsg->cmsg_type != SCM_RIGHTS || 87 cmsg->cmsg_len != CMSG_LEN(sizeof(fd))) { 88 errno = EINVAL; 89 return (-1); 90 } 91 92 bcopy(CMSG_DATA(cmsg), &fd, sizeof(fd)); 93 #ifndef MSG_CMSG_CLOEXEC 94 /* 95 * If the MSG_CMSG_CLOEXEC flag is not available we cannot set the 96 * close-on-exec flag atomically, but we still want to set it for 97 * consistency. 98 */ 99 (void) fcntl(fd, F_SETFD, FD_CLOEXEC); 100 #endif 101 102 return (fd); 103 } 104 105 static void 106 fd_wait(int fd, bool doread) 107 { 108 fd_set fds; 109 110 PJDLOG_ASSERT(fd >= 0); 111 112 FD_ZERO(&fds); 113 FD_SET(fd, &fds); 114 (void)select(fd + 1, doread ? &fds : NULL, doread ? NULL : &fds, 115 NULL, NULL); 116 } 117 118 static int 119 msg_recv(int sock, struct msghdr *msg) 120 { 121 int flags; 122 123 PJDLOG_ASSERT(sock >= 0); 124 125 #ifdef MSG_CMSG_CLOEXEC 126 flags = MSG_CMSG_CLOEXEC; 127 #else 128 flags = 0; 129 #endif 130 131 for (;;) { 132 fd_wait(sock, true); 133 if (recvmsg(sock, msg, flags) == -1) { 134 if (errno == EINTR) 135 continue; 136 return (-1); 137 } 138 break; 139 } 140 141 return (0); 142 } 143 144 static int 145 msg_send(int sock, const struct msghdr *msg) 146 { 147 148 PJDLOG_ASSERT(sock >= 0); 149 150 for (;;) { 151 fd_wait(sock, false); 152 if (sendmsg(sock, msg, 0) == -1) { 153 if (errno == EINTR) 154 continue; 155 return (-1); 156 } 157 break; 158 } 159 160 return (0); 161 } 162 163 int 164 cred_send(int sock) 165 { 166 unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))]; 167 struct msghdr msg; 168 struct cmsghdr *cmsg; 169 struct iovec iov; 170 uint8_t dummy; 171 172 bzero(credbuf, sizeof(credbuf)); 173 bzero(&msg, sizeof(msg)); 174 bzero(&iov, sizeof(iov)); 175 176 /* 177 * XXX: We send one byte along with the control message, because 178 * setting msg_iov to NULL only works if this is the first 179 * packet send over the socket. Once we send some data we 180 * won't be able to send credentials anymore. This is most 181 * likely a kernel bug. 182 */ 183 dummy = 0; 184 iov.iov_base = &dummy; 185 iov.iov_len = sizeof(dummy); 186 187 msg.msg_iov = &iov; 188 msg.msg_iovlen = 1; 189 msg.msg_control = credbuf; 190 msg.msg_controllen = sizeof(credbuf); 191 192 cmsg = CMSG_FIRSTHDR(&msg); 193 cmsg->cmsg_len = CMSG_LEN(sizeof(struct cmsgcred)); 194 cmsg->cmsg_level = SOL_SOCKET; 195 cmsg->cmsg_type = SCM_CREDS; 196 197 if (msg_send(sock, &msg) == -1) 198 return (-1); 199 200 return (0); 201 } 202 203 int 204 cred_recv(int sock, struct cmsgcred *cred) 205 { 206 unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))]; 207 struct msghdr msg; 208 struct cmsghdr *cmsg; 209 struct iovec iov; 210 uint8_t dummy; 211 212 bzero(credbuf, sizeof(credbuf)); 213 bzero(&msg, sizeof(msg)); 214 bzero(&iov, sizeof(iov)); 215 216 iov.iov_base = &dummy; 217 iov.iov_len = sizeof(dummy); 218 219 msg.msg_iov = &iov; 220 msg.msg_iovlen = 1; 221 msg.msg_control = credbuf; 222 msg.msg_controllen = sizeof(credbuf); 223 224 if (msg_recv(sock, &msg) == -1) 225 return (-1); 226 227 cmsg = CMSG_FIRSTHDR(&msg); 228 if (cmsg == NULL || 229 cmsg->cmsg_len != CMSG_LEN(sizeof(struct cmsgcred)) || 230 cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDS) { 231 errno = EINVAL; 232 return (-1); 233 } 234 bcopy(CMSG_DATA(cmsg), cred, sizeof(*cred)); 235 236 return (0); 237 } 238 239 static int 240 fd_package_send(int sock, const int *fds, size_t nfds) 241 { 242 struct msghdr msg; 243 struct cmsghdr *cmsg; 244 struct iovec iov; 245 unsigned int i; 246 int serrno, ret; 247 uint8_t dummy; 248 249 PJDLOG_ASSERT(sock >= 0); 250 PJDLOG_ASSERT(fds != NULL); 251 PJDLOG_ASSERT(nfds > 0); 252 253 bzero(&msg, sizeof(msg)); 254 255 /* 256 * XXX: Look into cred_send function for more details. 257 */ 258 dummy = 0; 259 iov.iov_base = &dummy; 260 iov.iov_len = sizeof(dummy); 261 262 msg.msg_iov = &iov; 263 msg.msg_iovlen = 1; 264 msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int)); 265 msg.msg_control = calloc(1, msg.msg_controllen); 266 if (msg.msg_control == NULL) 267 return (-1); 268 269 ret = -1; 270 271 for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL; 272 i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) { 273 if (msghdr_add_fd(cmsg, fds[i]) == -1) 274 goto end; 275 } 276 277 if (msg_send(sock, &msg) == -1) 278 goto end; 279 280 ret = 0; 281 end: 282 serrno = errno; 283 free(msg.msg_control); 284 errno = serrno; 285 return (ret); 286 } 287 288 static int 289 fd_package_recv(int sock, int *fds, size_t nfds) 290 { 291 struct msghdr msg; 292 struct cmsghdr *cmsg; 293 unsigned int i; 294 int serrno, ret; 295 struct iovec iov; 296 uint8_t dummy; 297 298 PJDLOG_ASSERT(sock >= 0); 299 PJDLOG_ASSERT(nfds > 0); 300 PJDLOG_ASSERT(fds != NULL); 301 302 i = 0; 303 bzero(&msg, sizeof(msg)); 304 bzero(&iov, sizeof(iov)); 305 306 /* 307 * XXX: Look into cred_send function for more details. 308 */ 309 iov.iov_base = &dummy; 310 iov.iov_len = sizeof(dummy); 311 312 msg.msg_iov = &iov; 313 msg.msg_iovlen = 1; 314 msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int)); 315 msg.msg_control = calloc(1, msg.msg_controllen); 316 if (msg.msg_control == NULL) 317 return (-1); 318 319 ret = -1; 320 321 if (msg_recv(sock, &msg) == -1) 322 goto end; 323 324 for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL; 325 i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) { 326 fds[i] = msghdr_get_fd(cmsg); 327 if (fds[i] < 0) 328 break; 329 } 330 331 if (cmsg != NULL || i < nfds) { 332 int fd; 333 334 /* 335 * We need to close all received descriptors, even if we have 336 * different control message (eg. SCM_CREDS) in between. 337 */ 338 for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; 339 cmsg = CMSG_NXTHDR(&msg, cmsg)) { 340 fd = msghdr_get_fd(cmsg); 341 if (fd >= 0) 342 close(fd); 343 } 344 errno = EINVAL; 345 goto end; 346 } 347 348 ret = 0; 349 end: 350 serrno = errno; 351 free(msg.msg_control); 352 errno = serrno; 353 return (ret); 354 } 355 356 int 357 fd_recv(int sock, int *fds, size_t nfds) 358 { 359 unsigned int i, step, j; 360 int ret, serrno; 361 362 if (nfds == 0 || fds == NULL) { 363 errno = EINVAL; 364 return (-1); 365 } 366 367 ret = i = step = 0; 368 while (i < nfds) { 369 if (PKG_MAX_SIZE < nfds - i) 370 step = PKG_MAX_SIZE; 371 else 372 step = nfds - i; 373 ret = fd_package_recv(sock, fds + i, step); 374 if (ret != 0) { 375 /* Close all received descriptors. */ 376 serrno = errno; 377 for (j = 0; j < i; j++) 378 close(fds[j]); 379 errno = serrno; 380 break; 381 } 382 i += step; 383 } 384 385 return (ret); 386 } 387 388 int 389 fd_send(int sock, const int *fds, size_t nfds) 390 { 391 unsigned int i, step; 392 int ret; 393 394 if (nfds == 0 || fds == NULL) { 395 errno = EINVAL; 396 return (-1); 397 } 398 399 ret = i = step = 0; 400 while (i < nfds) { 401 if (PKG_MAX_SIZE < nfds - i) 402 step = PKG_MAX_SIZE; 403 else 404 step = nfds - i; 405 ret = fd_package_send(sock, fds + i, step); 406 if (ret != 0) 407 break; 408 i += step; 409 } 410 411 return (ret); 412 } 413 414 int 415 buf_send(int sock, void *buf, size_t size) 416 { 417 ssize_t done; 418 unsigned char *ptr; 419 420 PJDLOG_ASSERT(sock >= 0); 421 PJDLOG_ASSERT(size > 0); 422 PJDLOG_ASSERT(buf != NULL); 423 424 ptr = buf; 425 do { 426 fd_wait(sock, false); 427 done = send(sock, ptr, size, 0); 428 if (done == -1) { 429 if (errno == EINTR) 430 continue; 431 return (-1); 432 } else if (done == 0) { 433 errno = ENOTCONN; 434 return (-1); 435 } 436 size -= done; 437 ptr += done; 438 } while (size > 0); 439 440 return (0); 441 } 442 443 int 444 buf_recv(int sock, void *buf, size_t size) 445 { 446 ssize_t done; 447 unsigned char *ptr; 448 449 PJDLOG_ASSERT(sock >= 0); 450 PJDLOG_ASSERT(buf != NULL); 451 452 ptr = buf; 453 while (size > 0) { 454 fd_wait(sock, true); 455 done = recv(sock, ptr, size, 0); 456 if (done == -1) { 457 if (errno == EINTR) 458 continue; 459 return (-1); 460 } else if (done == 0) { 461 errno = ENOTCONN; 462 return (-1); 463 } 464 size -= done; 465 ptr += done; 466 } 467 468 return (0); 469 } 470