1 /*- 2 * Copyright (c) 2011 The FreeBSD Foundation 3 * All rights reserved. 4 * 5 * This software was developed by Pawel Jakub Dawidek under sponsorship from 6 * the FreeBSD Foundation. 7 * 8 * Redistribution and use in source and binary forms, with or without 9 * modification, are permitted provided that the following conditions 10 * are met: 11 * 1. Redistributions of source code must retain the above copyright 12 * notice, this list of conditions and the following disclaimer. 13 * 2. Redistributions in binary form must reproduce the above copyright 14 * notice, this list of conditions and the following disclaimer in the 15 * documentation and/or other materials provided with the distribution. 16 * 17 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND 18 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE 21 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 22 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS 23 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 24 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 25 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 26 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF 27 * SUCH DAMAGE. 28 * 29 * $P4: //depot/projects/trustedbsd/openbsm/bin/auditdistd/proto_tls.c#2 $ 30 */ 31 32 #include <config/config.h> 33 34 #include <sys/param.h> /* MAXHOSTNAMELEN */ 35 #include <sys/socket.h> 36 37 #include <arpa/inet.h> 38 39 #include <netinet/in.h> 40 #include <netinet/tcp.h> 41 42 #include <errno.h> 43 #include <fcntl.h> 44 #include <netdb.h> 45 #include <signal.h> 46 #include <stdbool.h> 47 #include <stdint.h> 48 #include <stdio.h> 49 #include <string.h> 50 #include <unistd.h> 51 52 #include <openssl/err.h> 53 #include <openssl/ssl.h> 54 55 #include <compat/compat.h> 56 #ifndef HAVE_CLOSEFROM 57 #include <compat/closefrom.h> 58 #endif 59 #ifndef HAVE_STRLCPY 60 #include <compat/strlcpy.h> 61 #endif 62 63 #include "pjdlog.h" 64 #include "proto_impl.h" 65 #include "sandbox.h" 66 #include "subr.h" 67 68 #define TLS_CTX_MAGIC 0x715c7 69 struct tls_ctx { 70 int tls_magic; 71 struct proto_conn *tls_sock; 72 struct proto_conn *tls_tcp; 73 char tls_laddr[256]; 74 char tls_raddr[256]; 75 int tls_side; 76 #define TLS_SIDE_CLIENT 0 77 #define TLS_SIDE_SERVER_LISTEN 1 78 #define TLS_SIDE_SERVER_WORK 2 79 bool tls_wait_called; 80 }; 81 82 #define TLS_DEFAULT_TIMEOUT 30 83 84 static int tls_connect_wait(void *ctx, int timeout); 85 static void tls_close(void *ctx); 86 87 static void 88 block(int fd) 89 { 90 int flags; 91 92 flags = fcntl(fd, F_GETFL); 93 if (flags == -1) 94 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed"); 95 flags &= ~O_NONBLOCK; 96 if (fcntl(fd, F_SETFL, flags) == -1) 97 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed"); 98 } 99 100 static void 101 nonblock(int fd) 102 { 103 int flags; 104 105 flags = fcntl(fd, F_GETFL); 106 if (flags == -1) 107 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed"); 108 flags |= O_NONBLOCK; 109 if (fcntl(fd, F_SETFL, flags) == -1) 110 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed"); 111 } 112 113 static int 114 wait_for_fd(int fd, int timeout) 115 { 116 struct timeval tv; 117 fd_set fdset; 118 int error, ret; 119 120 error = 0; 121 122 for (;;) { 123 FD_ZERO(&fdset); 124 FD_SET(fd, &fdset); 125 126 tv.tv_sec = timeout; 127 tv.tv_usec = 0; 128 129 ret = select(fd + 1, NULL, &fdset, NULL, 130 timeout == -1 ? NULL : &tv); 131 if (ret == 0) { 132 error = ETIMEDOUT; 133 break; 134 } else if (ret == -1) { 135 if (errno == EINTR) 136 continue; 137 error = errno; 138 break; 139 } 140 PJDLOG_ASSERT(ret > 0); 141 PJDLOG_ASSERT(FD_ISSET(fd, &fdset)); 142 break; 143 } 144 145 return (error); 146 } 147 148 static void 149 ssl_log_errors(void) 150 { 151 unsigned long error; 152 153 while ((error = ERR_get_error()) != 0) 154 pjdlog_error("SSL error: %s", ERR_error_string(error, NULL)); 155 } 156 157 static int 158 ssl_check_error(SSL *ssl, int ret) 159 { 160 int error; 161 162 error = SSL_get_error(ssl, ret); 163 164 switch (error) { 165 case SSL_ERROR_NONE: 166 return (0); 167 case SSL_ERROR_WANT_READ: 168 pjdlog_debug(2, "SSL_ERROR_WANT_READ"); 169 return (-1); 170 case SSL_ERROR_WANT_WRITE: 171 pjdlog_debug(2, "SSL_ERROR_WANT_WRITE"); 172 return (-1); 173 case SSL_ERROR_ZERO_RETURN: 174 pjdlog_exitx(EX_OK, "Connection closed."); 175 case SSL_ERROR_SYSCALL: 176 ssl_log_errors(); 177 pjdlog_exitx(EX_TEMPFAIL, "SSL I/O error."); 178 case SSL_ERROR_SSL: 179 ssl_log_errors(); 180 pjdlog_exitx(EX_TEMPFAIL, "SSL protocol error."); 181 default: 182 ssl_log_errors(); 183 pjdlog_exitx(EX_TEMPFAIL, "Unknown SSL error (%d).", error); 184 } 185 } 186 187 static void 188 tcp_recv_ssl_send(int recvfd, SSL *sendssl) 189 { 190 static unsigned char buf[65536]; 191 ssize_t tcpdone; 192 int sendfd, ssldone; 193 194 sendfd = SSL_get_fd(sendssl); 195 PJDLOG_ASSERT(sendfd >= 0); 196 pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd); 197 for (;;) { 198 tcpdone = recv(recvfd, buf, sizeof(buf), 0); 199 pjdlog_debug(2, "%s: recv() returned %zd", __func__, tcpdone); 200 if (tcpdone == 0) { 201 pjdlog_debug(1, "Connection terminated."); 202 exit(0); 203 } else if (tcpdone == -1) { 204 if (errno == EINTR) 205 continue; 206 else if (errno == EAGAIN) 207 break; 208 pjdlog_exit(EX_TEMPFAIL, "recv() failed"); 209 } 210 for (;;) { 211 ssldone = SSL_write(sendssl, buf, (int)tcpdone); 212 pjdlog_debug(2, "%s: send() returned %d", __func__, 213 ssldone); 214 if (ssl_check_error(sendssl, ssldone) == -1) { 215 (void)wait_for_fd(sendfd, -1); 216 continue; 217 } 218 PJDLOG_ASSERT(ssldone == tcpdone); 219 break; 220 } 221 } 222 pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd); 223 } 224 225 static void 226 ssl_recv_tcp_send(SSL *recvssl, int sendfd) 227 { 228 static unsigned char buf[65536]; 229 unsigned char *ptr; 230 ssize_t tcpdone; 231 size_t todo; 232 int recvfd, ssldone; 233 234 recvfd = SSL_get_fd(recvssl); 235 PJDLOG_ASSERT(recvfd >= 0); 236 pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd); 237 for (;;) { 238 ssldone = SSL_read(recvssl, buf, sizeof(buf)); 239 pjdlog_debug(2, "%s: SSL_read() returned %d", __func__, 240 ssldone); 241 if (ssl_check_error(recvssl, ssldone) == -1) 242 break; 243 todo = (size_t)ssldone; 244 ptr = buf; 245 do { 246 tcpdone = send(sendfd, ptr, todo, MSG_NOSIGNAL); 247 pjdlog_debug(2, "%s: send() returned %zd", __func__, 248 tcpdone); 249 if (tcpdone == 0) { 250 pjdlog_debug(1, "Connection terminated."); 251 exit(0); 252 } else if (tcpdone == -1) { 253 if (errno == EINTR || errno == ENOBUFS) 254 continue; 255 if (errno == EAGAIN) { 256 (void)wait_for_fd(sendfd, -1); 257 continue; 258 } 259 pjdlog_exit(EX_TEMPFAIL, "send() failed"); 260 } 261 todo -= tcpdone; 262 ptr += tcpdone; 263 } while (todo > 0); 264 } 265 pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd); 266 } 267 268 static void 269 tls_loop(int sockfd, SSL *tcpssl) 270 { 271 fd_set fds; 272 int maxfd, tcpfd; 273 274 tcpfd = SSL_get_fd(tcpssl); 275 PJDLOG_ASSERT(tcpfd >= 0); 276 277 for (;;) { 278 FD_ZERO(&fds); 279 FD_SET(sockfd, &fds); 280 FD_SET(tcpfd, &fds); 281 maxfd = MAX(sockfd, tcpfd); 282 283 PJDLOG_ASSERT(maxfd + 1 <= (int)FD_SETSIZE); 284 if (select(maxfd + 1, &fds, NULL, NULL, NULL) == -1) { 285 if (errno == EINTR) 286 continue; 287 pjdlog_exit(EX_TEMPFAIL, "select() failed"); 288 } 289 if (FD_ISSET(sockfd, &fds)) 290 tcp_recv_ssl_send(sockfd, tcpssl); 291 if (FD_ISSET(tcpfd, &fds)) 292 ssl_recv_tcp_send(tcpssl, sockfd); 293 } 294 } 295 296 static void 297 tls_certificate_verify(SSL *ssl, const char *fingerprint) 298 { 299 unsigned char md[EVP_MAX_MD_SIZE]; 300 char mdstr[sizeof("SHA256=") - 1 + EVP_MAX_MD_SIZE * 3]; 301 char *mdstrp; 302 unsigned int i, mdsize; 303 X509 *cert; 304 305 if (fingerprint[0] == '\0') { 306 pjdlog_debug(1, "No fingerprint verification requested."); 307 return; 308 } 309 310 cert = SSL_get_peer_certificate(ssl); 311 if (cert == NULL) 312 pjdlog_exitx(EX_TEMPFAIL, "No peer certificate received."); 313 314 if (X509_digest(cert, EVP_sha256(), md, &mdsize) != 1) 315 pjdlog_exitx(EX_TEMPFAIL, "X509_digest() failed."); 316 PJDLOG_ASSERT(mdsize <= EVP_MAX_MD_SIZE); 317 318 X509_free(cert); 319 320 (void)strlcpy(mdstr, "SHA256=", sizeof(mdstr)); 321 mdstrp = mdstr + strlen(mdstr); 322 for (i = 0; i < mdsize; i++) { 323 PJDLOG_VERIFY(mdstrp + 3 <= mdstr + sizeof(mdstr)); 324 (void)sprintf(mdstrp, "%02hhX:", md[i]); 325 mdstrp += 3; 326 } 327 /* Clear last colon. */ 328 mdstrp[-1] = '\0'; 329 if (strcasecmp(mdstr, fingerprint) != 0) { 330 pjdlog_exitx(EX_NOPERM, 331 "Finger print doesn't match. Received \"%s\", expected \"%s\"", 332 mdstr, fingerprint); 333 } 334 } 335 336 static void 337 tls_exec_client(const char *user, int startfd, const char *srcaddr, 338 const char *dstaddr, const char *fingerprint, const char *defport, 339 int timeout, int debuglevel) 340 { 341 struct proto_conn *tcp; 342 char *saddr, *daddr; 343 SSL_CTX *sslctx; 344 SSL *ssl; 345 long ret; 346 int sockfd, tcpfd; 347 uint8_t connected; 348 349 pjdlog_debug_set(debuglevel); 350 pjdlog_prefix_set("[TLS sandbox] (client) "); 351 #ifdef HAVE_SETPROCTITLE 352 setproctitle("[TLS sandbox] (client) "); 353 #endif 354 proto_set("tcp:port", defport); 355 356 sockfd = startfd; 357 358 /* Change tls:// to tcp://. */ 359 if (srcaddr == NULL) { 360 saddr = NULL; 361 } else { 362 saddr = strdup(srcaddr); 363 if (saddr == NULL) 364 pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory."); 365 bcopy("tcp://", saddr, 6); 366 } 367 daddr = strdup(dstaddr); 368 if (daddr == NULL) 369 pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory."); 370 bcopy("tcp://", daddr, 6); 371 372 /* Establish TCP connection. */ 373 if (proto_connect(saddr, daddr, timeout, &tcp) == -1) 374 exit(EX_TEMPFAIL); 375 376 SSL_load_error_strings(); 377 SSL_library_init(); 378 379 /* 380 * TODO: On FreeBSD we could move this below sandbox() once libc and 381 * libcrypto use sysctl kern.arandom to obtain random data 382 * instead of /dev/urandom and friends. 383 */ 384 sslctx = SSL_CTX_new(TLSv1_client_method()); 385 if (sslctx == NULL) 386 pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed."); 387 388 if (sandbox(user, true, "proto_tls client: %s", dstaddr) != 0) 389 pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS client."); 390 pjdlog_debug(1, "Privileges successfully dropped."); 391 392 SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); 393 394 /* Load CA certs. */ 395 /* TODO */ 396 //SSL_CTX_load_verify_locations(sslctx, cacerts_file, NULL); 397 398 ssl = SSL_new(sslctx); 399 if (ssl == NULL) 400 pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed."); 401 402 tcpfd = proto_descriptor(tcp); 403 404 block(tcpfd); 405 406 if (SSL_set_fd(ssl, tcpfd) != 1) 407 pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed."); 408 409 ret = SSL_connect(ssl); 410 ssl_check_error(ssl, (int)ret); 411 412 nonblock(sockfd); 413 nonblock(tcpfd); 414 415 tls_certificate_verify(ssl, fingerprint); 416 417 /* 418 * The following byte is send to make proto_connect_wait() to work. 419 */ 420 connected = 1; 421 for (;;) { 422 switch (send(sockfd, &connected, sizeof(connected), 0)) { 423 case -1: 424 if (errno == EINTR || errno == ENOBUFS) 425 continue; 426 if (errno == EAGAIN) { 427 (void)wait_for_fd(sockfd, -1); 428 continue; 429 } 430 pjdlog_exit(EX_TEMPFAIL, "send() failed"); 431 case 0: 432 pjdlog_debug(1, "Connection terminated."); 433 exit(0); 434 case 1: 435 break; 436 } 437 break; 438 } 439 440 tls_loop(sockfd, ssl); 441 } 442 443 static void 444 tls_call_exec_client(struct proto_conn *sock, const char *srcaddr, 445 const char *dstaddr, int timeout) 446 { 447 char *timeoutstr, *startfdstr, *debugstr; 448 int startfd; 449 450 /* Declare that we are receiver. */ 451 proto_recv(sock, NULL, 0); 452 453 if (pjdlog_mode_get() == PJDLOG_MODE_STD) 454 startfd = 3; 455 else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */ 456 startfd = 0; 457 458 if (proto_descriptor(sock) != startfd) { 459 /* Move socketpair descriptor to descriptor number startfd. */ 460 if (dup2(proto_descriptor(sock), startfd) == -1) 461 pjdlog_exit(EX_OSERR, "dup2() failed"); 462 proto_close(sock); 463 } else { 464 /* 465 * The FD_CLOEXEC is cleared by dup2(2), so when we not 466 * call it, we have to clear it by hand in case it is set. 467 */ 468 if (fcntl(startfd, F_SETFD, 0) == -1) 469 pjdlog_exit(EX_OSERR, "fcntl() failed"); 470 } 471 472 closefrom(startfd + 1); 473 474 if (asprintf(&startfdstr, "%d", startfd) == -1) 475 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed"); 476 if (timeout == -1) 477 timeout = TLS_DEFAULT_TIMEOUT; 478 if (asprintf(&timeoutstr, "%d", timeout) == -1) 479 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed"); 480 if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1) 481 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed"); 482 483 execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls", 484 proto_get("user"), "client", startfdstr, 485 srcaddr == NULL ? "" : srcaddr, dstaddr, 486 proto_get("tls:fingerprint"), proto_get("tcp:port"), timeoutstr, 487 debugstr, NULL); 488 pjdlog_exit(EX_SOFTWARE, "execl() failed"); 489 } 490 491 static int 492 tls_connect(const char *srcaddr, const char *dstaddr, int timeout, void **ctxp) 493 { 494 struct tls_ctx *tlsctx; 495 struct proto_conn *sock; 496 pid_t pid; 497 int error; 498 499 PJDLOG_ASSERT(srcaddr == NULL || srcaddr[0] != '\0'); 500 PJDLOG_ASSERT(dstaddr != NULL); 501 PJDLOG_ASSERT(timeout >= -1); 502 PJDLOG_ASSERT(ctxp != NULL); 503 504 if (strncmp(dstaddr, "tls://", 6) != 0) 505 return (-1); 506 if (srcaddr != NULL && strncmp(srcaddr, "tls://", 6) != 0) 507 return (-1); 508 509 if (proto_connect(NULL, "socketpair://", -1, &sock) == -1) 510 return (errno); 511 512 #if 0 513 /* 514 * We use rfork() with the following flags to disable SIGCHLD 515 * delivery upon the sandbox process exit. 516 */ 517 pid = rfork(RFFDG | RFPROC | RFTSIGZMB | RFTSIGFLAGS(0)); 518 #else 519 /* 520 * We don't use rfork() to be able to log information about sandbox 521 * process exiting. 522 */ 523 pid = fork(); 524 #endif 525 switch (pid) { 526 case -1: 527 /* Failure. */ 528 error = errno; 529 proto_close(sock); 530 return (error); 531 case 0: 532 /* Child. */ 533 pjdlog_prefix_set("[TLS sandbox] (client) "); 534 #ifdef HAVE_SETPROCTITLE 535 setproctitle("[TLS sandbox] (client) "); 536 #endif 537 tls_call_exec_client(sock, srcaddr, dstaddr, timeout); 538 /* NOTREACHED */ 539 default: 540 /* Parent. */ 541 tlsctx = calloc(1, sizeof(*tlsctx)); 542 if (tlsctx == NULL) { 543 error = errno; 544 proto_close(sock); 545 (void)kill(pid, SIGKILL); 546 return (error); 547 } 548 proto_send(sock, NULL, 0); 549 tlsctx->tls_sock = sock; 550 tlsctx->tls_tcp = NULL; 551 tlsctx->tls_side = TLS_SIDE_CLIENT; 552 tlsctx->tls_wait_called = false; 553 tlsctx->tls_magic = TLS_CTX_MAGIC; 554 if (timeout >= 0) { 555 error = tls_connect_wait(tlsctx, timeout); 556 if (error != 0) { 557 (void)kill(pid, SIGKILL); 558 tls_close(tlsctx); 559 return (error); 560 } 561 } 562 *ctxp = tlsctx; 563 return (0); 564 } 565 } 566 567 static int 568 tls_connect_wait(void *ctx, int timeout) 569 { 570 struct tls_ctx *tlsctx = ctx; 571 int error, sockfd; 572 uint8_t connected; 573 574 PJDLOG_ASSERT(tlsctx != NULL); 575 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); 576 PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT); 577 PJDLOG_ASSERT(tlsctx->tls_sock != NULL); 578 PJDLOG_ASSERT(!tlsctx->tls_wait_called); 579 PJDLOG_ASSERT(timeout >= 0); 580 581 sockfd = proto_descriptor(tlsctx->tls_sock); 582 error = wait_for_fd(sockfd, timeout); 583 if (error != 0) 584 return (error); 585 586 for (;;) { 587 switch (recv(sockfd, &connected, sizeof(connected), 588 MSG_WAITALL)) { 589 case -1: 590 if (errno == EINTR || errno == ENOBUFS) 591 continue; 592 error = errno; 593 break; 594 case 0: 595 pjdlog_debug(1, "Connection terminated."); 596 error = ENOTCONN; 597 break; 598 case 1: 599 tlsctx->tls_wait_called = true; 600 break; 601 } 602 break; 603 } 604 605 return (error); 606 } 607 608 static int 609 tls_server(const char *lstaddr, void **ctxp) 610 { 611 struct proto_conn *tcp; 612 struct tls_ctx *tlsctx; 613 char *laddr; 614 int error; 615 616 if (strncmp(lstaddr, "tls://", 6) != 0) 617 return (-1); 618 619 tlsctx = malloc(sizeof(*tlsctx)); 620 if (tlsctx == NULL) { 621 pjdlog_warning("Unable to allocate memory."); 622 return (ENOMEM); 623 } 624 625 laddr = strdup(lstaddr); 626 if (laddr == NULL) { 627 free(tlsctx); 628 pjdlog_warning("Unable to allocate memory."); 629 return (ENOMEM); 630 } 631 bcopy("tcp://", laddr, 6); 632 633 if (proto_server(laddr, &tcp) == -1) { 634 error = errno; 635 free(tlsctx); 636 free(laddr); 637 return (error); 638 } 639 free(laddr); 640 641 tlsctx->tls_sock = NULL; 642 tlsctx->tls_tcp = tcp; 643 tlsctx->tls_side = TLS_SIDE_SERVER_LISTEN; 644 tlsctx->tls_wait_called = true; 645 tlsctx->tls_magic = TLS_CTX_MAGIC; 646 *ctxp = tlsctx; 647 648 return (0); 649 } 650 651 static void 652 tls_exec_server(const char *user, int startfd, const char *privkey, 653 const char *cert, int debuglevel) 654 { 655 SSL_CTX *sslctx; 656 SSL *ssl; 657 int sockfd, tcpfd, ret; 658 659 pjdlog_debug_set(debuglevel); 660 pjdlog_prefix_set("[TLS sandbox] (server) "); 661 #ifdef HAVE_SETPROCTITLE 662 setproctitle("[TLS sandbox] (server) "); 663 #endif 664 665 sockfd = startfd; 666 tcpfd = startfd + 1; 667 668 SSL_load_error_strings(); 669 SSL_library_init(); 670 671 sslctx = SSL_CTX_new(TLSv1_server_method()); 672 if (sslctx == NULL) 673 pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed."); 674 675 SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); 676 677 ssl = SSL_new(sslctx); 678 if (ssl == NULL) 679 pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed."); 680 681 if (SSL_use_RSAPrivateKey_file(ssl, privkey, SSL_FILETYPE_PEM) != 1) { 682 ssl_log_errors(); 683 pjdlog_exitx(EX_CONFIG, 684 "SSL_use_RSAPrivateKey_file(%s) failed.", privkey); 685 } 686 687 if (SSL_use_certificate_file(ssl, cert, SSL_FILETYPE_PEM) != 1) { 688 ssl_log_errors(); 689 pjdlog_exitx(EX_CONFIG, "SSL_use_certificate_file(%s) failed.", 690 cert); 691 } 692 693 if (sandbox(user, true, "proto_tls server") != 0) 694 pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS server."); 695 pjdlog_debug(1, "Privileges successfully dropped."); 696 697 nonblock(sockfd); 698 nonblock(tcpfd); 699 700 if (SSL_set_fd(ssl, tcpfd) != 1) 701 pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed."); 702 703 ret = SSL_accept(ssl); 704 ssl_check_error(ssl, ret); 705 706 tls_loop(sockfd, ssl); 707 } 708 709 static void 710 tls_call_exec_server(struct proto_conn *sock, struct proto_conn *tcp) 711 { 712 int startfd, sockfd, tcpfd, safefd; 713 char *startfdstr, *debugstr; 714 715 if (pjdlog_mode_get() == PJDLOG_MODE_STD) 716 startfd = 3; 717 else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */ 718 startfd = 0; 719 720 /* Declare that we are receiver. */ 721 proto_send(sock, NULL, 0); 722 723 sockfd = proto_descriptor(sock); 724 tcpfd = proto_descriptor(tcp); 725 726 safefd = MAX(sockfd, tcpfd); 727 safefd = MAX(safefd, startfd); 728 safefd++; 729 730 /* Move sockfd and tcpfd to safe numbers first. */ 731 if (dup2(sockfd, safefd) == -1) 732 pjdlog_exit(EX_OSERR, "dup2() failed"); 733 proto_close(sock); 734 sockfd = safefd; 735 if (dup2(tcpfd, safefd + 1) == -1) 736 pjdlog_exit(EX_OSERR, "dup2() failed"); 737 proto_close(tcp); 738 tcpfd = safefd + 1; 739 740 /* Move socketpair descriptor to descriptor number startfd. */ 741 if (dup2(sockfd, startfd) == -1) 742 pjdlog_exit(EX_OSERR, "dup2() failed"); 743 (void)close(sockfd); 744 /* Move tcp descriptor to descriptor number startfd + 1. */ 745 if (dup2(tcpfd, startfd + 1) == -1) 746 pjdlog_exit(EX_OSERR, "dup2() failed"); 747 (void)close(tcpfd); 748 749 closefrom(startfd + 2); 750 751 /* 752 * Even if FD_CLOEXEC was set on descriptors before dup2(), it should 753 * have been cleared on dup2(), but better be safe than sorry. 754 */ 755 if (fcntl(startfd, F_SETFD, 0) == -1) 756 pjdlog_exit(EX_OSERR, "fcntl() failed"); 757 if (fcntl(startfd + 1, F_SETFD, 0) == -1) 758 pjdlog_exit(EX_OSERR, "fcntl() failed"); 759 760 if (asprintf(&startfdstr, "%d", startfd) == -1) 761 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed"); 762 if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1) 763 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed"); 764 765 execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls", 766 proto_get("user"), "server", startfdstr, proto_get("tls:keyfile"), 767 proto_get("tls:certfile"), debugstr, NULL); 768 pjdlog_exit(EX_SOFTWARE, "execl() failed"); 769 } 770 771 static int 772 tls_accept(void *ctx, void **newctxp) 773 { 774 struct tls_ctx *tlsctx = ctx; 775 struct tls_ctx *newtlsctx; 776 struct proto_conn *sock, *tcp; 777 pid_t pid; 778 int error; 779 780 PJDLOG_ASSERT(tlsctx != NULL); 781 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); 782 PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_SERVER_LISTEN); 783 784 if (proto_connect(NULL, "socketpair://", -1, &sock) == -1) 785 return (errno); 786 787 /* Accept TCP connection. */ 788 if (proto_accept(tlsctx->tls_tcp, &tcp) == -1) { 789 error = errno; 790 proto_close(sock); 791 return (error); 792 } 793 794 pid = fork(); 795 switch (pid) { 796 case -1: 797 /* Failure. */ 798 error = errno; 799 proto_close(sock); 800 return (error); 801 case 0: 802 /* Child. */ 803 pjdlog_prefix_set("[TLS sandbox] (server) "); 804 #ifdef HAVE_SETPROCTITLE 805 setproctitle("[TLS sandbox] (server) "); 806 #endif 807 /* Close listen socket. */ 808 proto_close(tlsctx->tls_tcp); 809 tls_call_exec_server(sock, tcp); 810 /* NOTREACHED */ 811 PJDLOG_ABORT("Unreachable."); 812 default: 813 /* Parent. */ 814 newtlsctx = calloc(1, sizeof(*tlsctx)); 815 if (newtlsctx == NULL) { 816 error = errno; 817 proto_close(sock); 818 proto_close(tcp); 819 (void)kill(pid, SIGKILL); 820 return (error); 821 } 822 proto_local_address(tcp, newtlsctx->tls_laddr, 823 sizeof(newtlsctx->tls_laddr)); 824 PJDLOG_ASSERT(strncmp(newtlsctx->tls_laddr, "tcp://", 6) == 0); 825 bcopy("tls://", newtlsctx->tls_laddr, 6); 826 *strrchr(newtlsctx->tls_laddr, ':') = '\0'; 827 proto_remote_address(tcp, newtlsctx->tls_raddr, 828 sizeof(newtlsctx->tls_raddr)); 829 PJDLOG_ASSERT(strncmp(newtlsctx->tls_raddr, "tcp://", 6) == 0); 830 bcopy("tls://", newtlsctx->tls_raddr, 6); 831 *strrchr(newtlsctx->tls_raddr, ':') = '\0'; 832 proto_close(tcp); 833 proto_recv(sock, NULL, 0); 834 newtlsctx->tls_sock = sock; 835 newtlsctx->tls_tcp = NULL; 836 newtlsctx->tls_wait_called = true; 837 newtlsctx->tls_side = TLS_SIDE_SERVER_WORK; 838 newtlsctx->tls_magic = TLS_CTX_MAGIC; 839 *newctxp = newtlsctx; 840 return (0); 841 } 842 } 843 844 static int 845 tls_wrap(int fd, bool client, void **ctxp) 846 { 847 struct tls_ctx *tlsctx; 848 struct proto_conn *sock; 849 int error; 850 851 tlsctx = calloc(1, sizeof(*tlsctx)); 852 if (tlsctx == NULL) 853 return (errno); 854 855 if (proto_wrap("socketpair", client, fd, &sock) == -1) { 856 error = errno; 857 free(tlsctx); 858 return (error); 859 } 860 861 tlsctx->tls_sock = sock; 862 tlsctx->tls_tcp = NULL; 863 tlsctx->tls_wait_called = (client ? false : true); 864 tlsctx->tls_side = (client ? TLS_SIDE_CLIENT : TLS_SIDE_SERVER_WORK); 865 tlsctx->tls_magic = TLS_CTX_MAGIC; 866 *ctxp = tlsctx; 867 868 return (0); 869 } 870 871 static int 872 tls_send(void *ctx, const unsigned char *data, size_t size, int fd) 873 { 874 struct tls_ctx *tlsctx = ctx; 875 876 PJDLOG_ASSERT(tlsctx != NULL); 877 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); 878 PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT || 879 tlsctx->tls_side == TLS_SIDE_SERVER_WORK); 880 PJDLOG_ASSERT(tlsctx->tls_sock != NULL); 881 PJDLOG_ASSERT(tlsctx->tls_wait_called); 882 PJDLOG_ASSERT(fd == -1); 883 884 if (proto_send(tlsctx->tls_sock, data, size) == -1) 885 return (errno); 886 887 return (0); 888 } 889 890 static int 891 tls_recv(void *ctx, unsigned char *data, size_t size, int *fdp) 892 { 893 struct tls_ctx *tlsctx = ctx; 894 895 PJDLOG_ASSERT(tlsctx != NULL); 896 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); 897 PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT || 898 tlsctx->tls_side == TLS_SIDE_SERVER_WORK); 899 PJDLOG_ASSERT(tlsctx->tls_sock != NULL); 900 PJDLOG_ASSERT(tlsctx->tls_wait_called); 901 PJDLOG_ASSERT(fdp == NULL); 902 903 if (proto_recv(tlsctx->tls_sock, data, size) == -1) 904 return (errno); 905 906 return (0); 907 } 908 909 static int 910 tls_descriptor(const void *ctx) 911 { 912 const struct tls_ctx *tlsctx = ctx; 913 914 PJDLOG_ASSERT(tlsctx != NULL); 915 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); 916 917 switch (tlsctx->tls_side) { 918 case TLS_SIDE_CLIENT: 919 case TLS_SIDE_SERVER_WORK: 920 PJDLOG_ASSERT(tlsctx->tls_sock != NULL); 921 922 return (proto_descriptor(tlsctx->tls_sock)); 923 case TLS_SIDE_SERVER_LISTEN: 924 PJDLOG_ASSERT(tlsctx->tls_tcp != NULL); 925 926 return (proto_descriptor(tlsctx->tls_tcp)); 927 default: 928 PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side); 929 } 930 } 931 932 static bool 933 tcp_address_match(const void *ctx, const char *addr) 934 { 935 const struct tls_ctx *tlsctx = ctx; 936 937 PJDLOG_ASSERT(tlsctx != NULL); 938 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); 939 940 return (strcmp(tlsctx->tls_raddr, addr) == 0); 941 } 942 943 static void 944 tls_local_address(const void *ctx, char *addr, size_t size) 945 { 946 const struct tls_ctx *tlsctx = ctx; 947 948 PJDLOG_ASSERT(tlsctx != NULL); 949 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); 950 PJDLOG_ASSERT(tlsctx->tls_wait_called); 951 952 switch (tlsctx->tls_side) { 953 case TLS_SIDE_CLIENT: 954 PJDLOG_ASSERT(tlsctx->tls_sock != NULL); 955 956 PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size); 957 break; 958 case TLS_SIDE_SERVER_WORK: 959 PJDLOG_ASSERT(tlsctx->tls_sock != NULL); 960 961 PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_laddr, size) < size); 962 break; 963 case TLS_SIDE_SERVER_LISTEN: 964 PJDLOG_ASSERT(tlsctx->tls_tcp != NULL); 965 966 proto_local_address(tlsctx->tls_tcp, addr, size); 967 PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0); 968 /* Replace tcp:// prefix with tls:// */ 969 bcopy("tls://", addr, 6); 970 break; 971 default: 972 PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side); 973 } 974 } 975 976 static void 977 tls_remote_address(const void *ctx, char *addr, size_t size) 978 { 979 const struct tls_ctx *tlsctx = ctx; 980 981 PJDLOG_ASSERT(tlsctx != NULL); 982 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); 983 PJDLOG_ASSERT(tlsctx->tls_wait_called); 984 985 switch (tlsctx->tls_side) { 986 case TLS_SIDE_CLIENT: 987 PJDLOG_ASSERT(tlsctx->tls_sock != NULL); 988 989 PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size); 990 break; 991 case TLS_SIDE_SERVER_WORK: 992 PJDLOG_ASSERT(tlsctx->tls_sock != NULL); 993 994 PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_raddr, size) < size); 995 break; 996 case TLS_SIDE_SERVER_LISTEN: 997 PJDLOG_ASSERT(tlsctx->tls_tcp != NULL); 998 999 proto_remote_address(tlsctx->tls_tcp, addr, size); 1000 PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0); 1001 /* Replace tcp:// prefix with tls:// */ 1002 bcopy("tls://", addr, 6); 1003 break; 1004 default: 1005 PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side); 1006 } 1007 } 1008 1009 static void 1010 tls_close(void *ctx) 1011 { 1012 struct tls_ctx *tlsctx = ctx; 1013 1014 PJDLOG_ASSERT(tlsctx != NULL); 1015 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); 1016 1017 if (tlsctx->tls_sock != NULL) { 1018 proto_close(tlsctx->tls_sock); 1019 tlsctx->tls_sock = NULL; 1020 } 1021 if (tlsctx->tls_tcp != NULL) { 1022 proto_close(tlsctx->tls_tcp); 1023 tlsctx->tls_tcp = NULL; 1024 } 1025 tlsctx->tls_side = 0; 1026 tlsctx->tls_magic = 0; 1027 free(tlsctx); 1028 } 1029 1030 static int 1031 tls_exec(int argc, char *argv[]) 1032 { 1033 1034 PJDLOG_ASSERT(argc > 3); 1035 PJDLOG_ASSERT(strcmp(argv[0], "tls") == 0); 1036 1037 pjdlog_init(atoi(argv[3]) == 0 ? PJDLOG_MODE_SYSLOG : PJDLOG_MODE_STD); 1038 1039 if (strcmp(argv[2], "client") == 0) { 1040 if (argc != 10) 1041 return (EINVAL); 1042 tls_exec_client(argv[1], atoi(argv[3]), 1043 argv[4][0] == '\0' ? NULL : argv[4], argv[5], argv[6], 1044 argv[7], atoi(argv[8]), atoi(argv[9])); 1045 } else if (strcmp(argv[2], "server") == 0) { 1046 if (argc != 7) 1047 return (EINVAL); 1048 tls_exec_server(argv[1], atoi(argv[3]), argv[4], argv[5], 1049 atoi(argv[6])); 1050 } 1051 return (EINVAL); 1052 } 1053 1054 static struct proto tls_proto = { 1055 .prt_name = "tls", 1056 .prt_connect = tls_connect, 1057 .prt_connect_wait = tls_connect_wait, 1058 .prt_server = tls_server, 1059 .prt_accept = tls_accept, 1060 .prt_wrap = tls_wrap, 1061 .prt_send = tls_send, 1062 .prt_recv = tls_recv, 1063 .prt_descriptor = tls_descriptor, 1064 .prt_address_match = tcp_address_match, 1065 .prt_local_address = tls_local_address, 1066 .prt_remote_address = tls_remote_address, 1067 .prt_close = tls_close, 1068 .prt_exec = tls_exec 1069 }; 1070 1071 static __constructor void 1072 tls_ctor(void) 1073 { 1074 1075 proto_register(&tls_proto, false); 1076 } 1077