1 /*- 2 * Copyright (c) 2011 The FreeBSD Foundation 3 * All rights reserved. 4 * 5 * This software was developed by Pawel Jakub Dawidek under sponsorship from 6 * the FreeBSD Foundation. 7 * 8 * Redistribution and use in source and binary forms, with or without 9 * modification, are permitted provided that the following conditions 10 * are met: 11 * 1. Redistributions of source code must retain the above copyright 12 * notice, this list of conditions and the following disclaimer. 13 * 2. Redistributions in binary form must reproduce the above copyright 14 * notice, this list of conditions and the following disclaimer in the 15 * documentation and/or other materials provided with the distribution. 16 * 17 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND 18 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE 21 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 22 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS 23 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 24 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 25 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 26 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF 27 * SUCH DAMAGE. 28 */ 29 30 #include <config/config.h> 31 32 #include <sys/param.h> /* MAXHOSTNAMELEN */ 33 #include <sys/socket.h> 34 35 #include <arpa/inet.h> 36 37 #include <netinet/in.h> 38 #include <netinet/tcp.h> 39 40 #include <errno.h> 41 #include <fcntl.h> 42 #include <netdb.h> 43 #include <signal.h> 44 #include <stdbool.h> 45 #include <stdint.h> 46 #include <stdio.h> 47 #include <string.h> 48 #include <unistd.h> 49 50 #include <openssl/err.h> 51 #include <openssl/ssl.h> 52 53 #include <compat/compat.h> 54 #ifndef HAVE_CLOSEFROM 55 #include <compat/closefrom.h> 56 #endif 57 #ifndef HAVE_STRLCPY 58 #include <compat/strlcpy.h> 59 #endif 60 61 #include "pjdlog.h" 62 #include "proto_impl.h" 63 #include "sandbox.h" 64 #include "subr.h" 65 66 #define TLS_CTX_MAGIC 0x715c7 67 struct tls_ctx { 68 int tls_magic; 69 struct proto_conn *tls_sock; 70 struct proto_conn *tls_tcp; 71 char tls_laddr[256]; 72 char tls_raddr[256]; 73 int tls_side; 74 #define TLS_SIDE_CLIENT 0 75 #define TLS_SIDE_SERVER_LISTEN 1 76 #define TLS_SIDE_SERVER_WORK 2 77 bool tls_wait_called; 78 }; 79 80 #define TLS_DEFAULT_TIMEOUT 30 81 82 static int tls_connect_wait(void *ctx, int timeout); 83 static void tls_close(void *ctx); 84 85 static void 86 block(int fd) 87 { 88 int flags; 89 90 flags = fcntl(fd, F_GETFL); 91 if (flags == -1) 92 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed"); 93 flags &= ~O_NONBLOCK; 94 if (fcntl(fd, F_SETFL, flags) == -1) 95 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed"); 96 } 97 98 static void 99 nonblock(int fd) 100 { 101 int flags; 102 103 flags = fcntl(fd, F_GETFL); 104 if (flags == -1) 105 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed"); 106 flags |= O_NONBLOCK; 107 if (fcntl(fd, F_SETFL, flags) == -1) 108 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed"); 109 } 110 111 static int 112 wait_for_fd(int fd, int timeout) 113 { 114 struct timeval tv; 115 fd_set fdset; 116 int error, ret; 117 118 error = 0; 119 120 for (;;) { 121 FD_ZERO(&fdset); 122 FD_SET(fd, &fdset); 123 124 tv.tv_sec = timeout; 125 tv.tv_usec = 0; 126 127 ret = select(fd + 1, NULL, &fdset, NULL, 128 timeout == -1 ? NULL : &tv); 129 if (ret == 0) { 130 error = ETIMEDOUT; 131 break; 132 } else if (ret == -1) { 133 if (errno == EINTR) 134 continue; 135 error = errno; 136 break; 137 } 138 PJDLOG_ASSERT(ret > 0); 139 PJDLOG_ASSERT(FD_ISSET(fd, &fdset)); 140 break; 141 } 142 143 return (error); 144 } 145 146 static void 147 ssl_log_errors(void) 148 { 149 unsigned long error; 150 151 while ((error = ERR_get_error()) != 0) 152 pjdlog_error("SSL error: %s", ERR_error_string(error, NULL)); 153 } 154 155 static int 156 ssl_check_error(SSL *ssl, int ret) 157 { 158 int error; 159 160 error = SSL_get_error(ssl, ret); 161 162 switch (error) { 163 case SSL_ERROR_NONE: 164 return (0); 165 case SSL_ERROR_WANT_READ: 166 pjdlog_debug(2, "SSL_ERROR_WANT_READ"); 167 return (-1); 168 case SSL_ERROR_WANT_WRITE: 169 pjdlog_debug(2, "SSL_ERROR_WANT_WRITE"); 170 return (-1); 171 case SSL_ERROR_ZERO_RETURN: 172 pjdlog_exitx(EX_OK, "Connection closed."); 173 case SSL_ERROR_SYSCALL: 174 ssl_log_errors(); 175 pjdlog_exitx(EX_TEMPFAIL, "SSL I/O error."); 176 case SSL_ERROR_SSL: 177 ssl_log_errors(); 178 pjdlog_exitx(EX_TEMPFAIL, "SSL protocol error."); 179 default: 180 ssl_log_errors(); 181 pjdlog_exitx(EX_TEMPFAIL, "Unknown SSL error (%d).", error); 182 } 183 } 184 185 static void 186 tcp_recv_ssl_send(int recvfd, SSL *sendssl) 187 { 188 static unsigned char buf[65536]; 189 ssize_t tcpdone; 190 int sendfd, ssldone; 191 192 sendfd = SSL_get_fd(sendssl); 193 PJDLOG_ASSERT(sendfd >= 0); 194 pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd); 195 for (;;) { 196 tcpdone = recv(recvfd, buf, sizeof(buf), 0); 197 pjdlog_debug(2, "%s: recv() returned %zd", __func__, tcpdone); 198 if (tcpdone == 0) { 199 pjdlog_debug(1, "Connection terminated."); 200 exit(0); 201 } else if (tcpdone == -1) { 202 if (errno == EINTR) 203 continue; 204 else if (errno == EAGAIN) 205 break; 206 pjdlog_exit(EX_TEMPFAIL, "recv() failed"); 207 } 208 for (;;) { 209 ssldone = SSL_write(sendssl, buf, (int)tcpdone); 210 pjdlog_debug(2, "%s: send() returned %d", __func__, 211 ssldone); 212 if (ssl_check_error(sendssl, ssldone) == -1) { 213 (void)wait_for_fd(sendfd, -1); 214 continue; 215 } 216 PJDLOG_ASSERT(ssldone == tcpdone); 217 break; 218 } 219 } 220 pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd); 221 } 222 223 static void 224 ssl_recv_tcp_send(SSL *recvssl, int sendfd) 225 { 226 static unsigned char buf[65536]; 227 unsigned char *ptr; 228 ssize_t tcpdone; 229 size_t todo; 230 int recvfd, ssldone; 231 232 recvfd = SSL_get_fd(recvssl); 233 PJDLOG_ASSERT(recvfd >= 0); 234 pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd); 235 for (;;) { 236 ssldone = SSL_read(recvssl, buf, sizeof(buf)); 237 pjdlog_debug(2, "%s: SSL_read() returned %d", __func__, 238 ssldone); 239 if (ssl_check_error(recvssl, ssldone) == -1) 240 break; 241 todo = (size_t)ssldone; 242 ptr = buf; 243 do { 244 tcpdone = send(sendfd, ptr, todo, MSG_NOSIGNAL); 245 pjdlog_debug(2, "%s: send() returned %zd", __func__, 246 tcpdone); 247 if (tcpdone == 0) { 248 pjdlog_debug(1, "Connection terminated."); 249 exit(0); 250 } else if (tcpdone == -1) { 251 if (errno == EINTR || errno == ENOBUFS) 252 continue; 253 if (errno == EAGAIN) { 254 (void)wait_for_fd(sendfd, -1); 255 continue; 256 } 257 pjdlog_exit(EX_TEMPFAIL, "send() failed"); 258 } 259 todo -= tcpdone; 260 ptr += tcpdone; 261 } while (todo > 0); 262 } 263 pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd); 264 } 265 266 static void 267 tls_loop(int sockfd, SSL *tcpssl) 268 { 269 fd_set fds; 270 int maxfd, tcpfd; 271 272 tcpfd = SSL_get_fd(tcpssl); 273 PJDLOG_ASSERT(tcpfd >= 0); 274 275 for (;;) { 276 FD_ZERO(&fds); 277 FD_SET(sockfd, &fds); 278 FD_SET(tcpfd, &fds); 279 maxfd = MAX(sockfd, tcpfd); 280 281 PJDLOG_ASSERT(maxfd + 1 <= (int)FD_SETSIZE); 282 if (select(maxfd + 1, &fds, NULL, NULL, NULL) == -1) { 283 if (errno == EINTR) 284 continue; 285 pjdlog_exit(EX_TEMPFAIL, "select() failed"); 286 } 287 if (FD_ISSET(sockfd, &fds)) 288 tcp_recv_ssl_send(sockfd, tcpssl); 289 if (FD_ISSET(tcpfd, &fds)) 290 ssl_recv_tcp_send(tcpssl, sockfd); 291 } 292 } 293 294 static void 295 tls_certificate_verify(SSL *ssl, const char *fingerprint) 296 { 297 unsigned char md[EVP_MAX_MD_SIZE]; 298 char mdstr[sizeof("SHA256=") - 1 + EVP_MAX_MD_SIZE * 3]; 299 char *mdstrp; 300 unsigned int i, mdsize; 301 X509 *cert; 302 303 if (fingerprint[0] == '\0') { 304 pjdlog_debug(1, "No fingerprint verification requested."); 305 return; 306 } 307 308 cert = SSL_get_peer_certificate(ssl); 309 if (cert == NULL) 310 pjdlog_exitx(EX_TEMPFAIL, "No peer certificate received."); 311 312 if (X509_digest(cert, EVP_sha256(), md, &mdsize) != 1) 313 pjdlog_exitx(EX_TEMPFAIL, "X509_digest() failed."); 314 PJDLOG_ASSERT(mdsize <= EVP_MAX_MD_SIZE); 315 316 X509_free(cert); 317 318 (void)strlcpy(mdstr, "SHA256=", sizeof(mdstr)); 319 mdstrp = mdstr + strlen(mdstr); 320 for (i = 0; i < mdsize; i++) { 321 PJDLOG_VERIFY(mdstrp + 3 <= mdstr + sizeof(mdstr)); 322 (void)sprintf(mdstrp, "%02hhX:", md[i]); 323 mdstrp += 3; 324 } 325 /* Clear last colon. */ 326 mdstrp[-1] = '\0'; 327 if (strcasecmp(mdstr, fingerprint) != 0) { 328 pjdlog_exitx(EX_NOPERM, 329 "Finger print doesn't match. Received \"%s\", expected \"%s\"", 330 mdstr, fingerprint); 331 } 332 } 333 334 static void 335 tls_exec_client(const char *user, int startfd, const char *srcaddr, 336 const char *dstaddr, const char *fingerprint, const char *defport, 337 int timeout, int debuglevel) 338 { 339 struct proto_conn *tcp; 340 char *saddr, *daddr; 341 SSL_CTX *sslctx; 342 SSL *ssl; 343 long ret; 344 int sockfd, tcpfd; 345 uint8_t connected; 346 347 pjdlog_debug_set(debuglevel); 348 pjdlog_prefix_set("[TLS sandbox] (client) "); 349 #ifdef HAVE_SETPROCTITLE 350 setproctitle("[TLS sandbox] (client) "); 351 #endif 352 proto_set("tcp:port", defport); 353 354 sockfd = startfd; 355 356 /* Change tls:// to tcp://. */ 357 if (srcaddr == NULL) { 358 saddr = NULL; 359 } else { 360 saddr = strdup(srcaddr); 361 if (saddr == NULL) 362 pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory."); 363 bcopy("tcp://", saddr, 6); 364 } 365 daddr = strdup(dstaddr); 366 if (daddr == NULL) 367 pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory."); 368 bcopy("tcp://", daddr, 6); 369 370 /* Establish TCP connection. */ 371 if (proto_connect(saddr, daddr, timeout, &tcp) == -1) 372 exit(EX_TEMPFAIL); 373 374 #if OPENSSL_VERSION_NUMBER < 0x10100000L 375 SSL_load_error_strings(); 376 SSL_library_init(); 377 #endif 378 379 /* 380 * TODO: On FreeBSD we could move this below sandbox() once libc and 381 * libcrypto use sysctl kern.arandom to obtain random data 382 * instead of /dev/urandom and friends. 383 */ 384 sslctx = SSL_CTX_new(TLS_client_method()); 385 if (sslctx == NULL) 386 pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed."); 387 388 if (sandbox(user, true, "proto_tls client: %s", dstaddr) != 0) 389 pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS client."); 390 pjdlog_debug(1, "Privileges successfully dropped."); 391 392 SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); 393 394 /* Load CA certs. */ 395 /* TODO */ 396 //SSL_CTX_load_verify_locations(sslctx, cacerts_file, NULL); 397 398 ssl = SSL_new(sslctx); 399 if (ssl == NULL) 400 pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed."); 401 402 tcpfd = proto_descriptor(tcp); 403 404 block(tcpfd); 405 406 if (SSL_set_fd(ssl, tcpfd) != 1) 407 pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed."); 408 409 ret = SSL_connect(ssl); 410 ssl_check_error(ssl, (int)ret); 411 412 nonblock(sockfd); 413 nonblock(tcpfd); 414 415 tls_certificate_verify(ssl, fingerprint); 416 417 /* 418 * The following byte is sent to make proto_connect_wait() work. 419 */ 420 connected = 1; 421 for (;;) { 422 switch (send(sockfd, &connected, sizeof(connected), 0)) { 423 case -1: 424 if (errno == EINTR || errno == ENOBUFS) 425 continue; 426 if (errno == EAGAIN) { 427 (void)wait_for_fd(sockfd, -1); 428 continue; 429 } 430 pjdlog_exit(EX_TEMPFAIL, "send() failed"); 431 case 0: 432 pjdlog_debug(1, "Connection terminated."); 433 exit(0); 434 case 1: 435 break; 436 } 437 break; 438 } 439 440 tls_loop(sockfd, ssl); 441 } 442 443 static void 444 tls_call_exec_client(struct proto_conn *sock, const char *srcaddr, 445 const char *dstaddr, int timeout) 446 { 447 char *timeoutstr, *startfdstr, *debugstr; 448 int startfd; 449 450 /* Declare that we are receiver. */ 451 proto_recv(sock, NULL, 0); 452 453 if (pjdlog_mode_get() == PJDLOG_MODE_STD) 454 startfd = 3; 455 else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */ 456 startfd = 0; 457 458 if (proto_descriptor(sock) != startfd) { 459 /* Move socketpair descriptor to descriptor number startfd. */ 460 if (dup2(proto_descriptor(sock), startfd) == -1) 461 pjdlog_exit(EX_OSERR, "dup2() failed"); 462 proto_close(sock); 463 } else { 464 /* 465 * The FD_CLOEXEC is cleared by dup2(2), so when we do not 466 * call it, we have to clear it by hand in case it is set. 467 */ 468 if (fcntl(startfd, F_SETFD, 0) == -1) 469 pjdlog_exit(EX_OSERR, "fcntl() failed"); 470 } 471 472 closefrom(startfd + 1); 473 474 if (asprintf(&startfdstr, "%d", startfd) == -1) 475 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed"); 476 if (timeout == -1) 477 timeout = TLS_DEFAULT_TIMEOUT; 478 if (asprintf(&timeoutstr, "%d", timeout) == -1) 479 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed"); 480 if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1) 481 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed"); 482 483 execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls", 484 proto_get("user"), "client", startfdstr, 485 srcaddr == NULL ? "" : srcaddr, dstaddr, 486 proto_get("tls:fingerprint"), proto_get("tcp:port"), timeoutstr, 487 debugstr, NULL); 488 pjdlog_exit(EX_SOFTWARE, "execl() failed"); 489 } 490 491 static int 492 tls_connect(const char *srcaddr, const char *dstaddr, int timeout, void **ctxp) 493 { 494 struct tls_ctx *tlsctx; 495 struct proto_conn *sock; 496 pid_t pid; 497 int error; 498 499 PJDLOG_ASSERT(srcaddr == NULL || srcaddr[0] != '\0'); 500 PJDLOG_ASSERT(dstaddr != NULL); 501 PJDLOG_ASSERT(timeout >= -1); 502 PJDLOG_ASSERT(ctxp != NULL); 503 504 if (strncmp(dstaddr, "tls://", 6) != 0) 505 return (-1); 506 if (srcaddr != NULL && strncmp(srcaddr, "tls://", 6) != 0) 507 return (-1); 508 509 if (proto_connect(NULL, "socketpair://", -1, &sock) == -1) 510 return (errno); 511 512 #if 0 513 /* 514 * We use rfork() with the following flags to disable SIGCHLD 515 * delivery upon the sandbox process exit. 516 */ 517 pid = rfork(RFFDG | RFPROC | RFTSIGZMB | RFTSIGFLAGS(0)); 518 #else 519 /* 520 * We don't use rfork() to be able to log information about sandbox 521 * process exiting. 522 */ 523 pid = fork(); 524 #endif 525 switch (pid) { 526 case -1: 527 /* Failure. */ 528 error = errno; 529 proto_close(sock); 530 return (error); 531 case 0: 532 /* Child. */ 533 pjdlog_prefix_set("[TLS sandbox] (client) "); 534 #ifdef HAVE_SETPROCTITLE 535 setproctitle("[TLS sandbox] (client) "); 536 #endif 537 tls_call_exec_client(sock, srcaddr, dstaddr, timeout); 538 /* NOTREACHED */ 539 default: 540 /* Parent. */ 541 tlsctx = calloc(1, sizeof(*tlsctx)); 542 if (tlsctx == NULL) { 543 error = errno; 544 proto_close(sock); 545 (void)kill(pid, SIGKILL); 546 return (error); 547 } 548 proto_send(sock, NULL, 0); 549 tlsctx->tls_sock = sock; 550 tlsctx->tls_tcp = NULL; 551 tlsctx->tls_side = TLS_SIDE_CLIENT; 552 tlsctx->tls_wait_called = false; 553 tlsctx->tls_magic = TLS_CTX_MAGIC; 554 if (timeout >= 0) { 555 error = tls_connect_wait(tlsctx, timeout); 556 if (error != 0) { 557 (void)kill(pid, SIGKILL); 558 tls_close(tlsctx); 559 return (error); 560 } 561 } 562 *ctxp = tlsctx; 563 return (0); 564 } 565 } 566 567 static int 568 tls_connect_wait(void *ctx, int timeout) 569 { 570 struct tls_ctx *tlsctx = ctx; 571 int error, sockfd; 572 uint8_t connected; 573 574 PJDLOG_ASSERT(tlsctx != NULL); 575 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); 576 PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT); 577 PJDLOG_ASSERT(tlsctx->tls_sock != NULL); 578 PJDLOG_ASSERT(!tlsctx->tls_wait_called); 579 PJDLOG_ASSERT(timeout >= 0); 580 581 sockfd = proto_descriptor(tlsctx->tls_sock); 582 error = wait_for_fd(sockfd, timeout); 583 if (error != 0) 584 return (error); 585 586 for (;;) { 587 switch (recv(sockfd, &connected, sizeof(connected), 588 MSG_WAITALL)) { 589 case -1: 590 if (errno == EINTR || errno == ENOBUFS) 591 continue; 592 error = errno; 593 break; 594 case 0: 595 pjdlog_debug(1, "Connection terminated."); 596 error = ENOTCONN; 597 break; 598 case 1: 599 tlsctx->tls_wait_called = true; 600 break; 601 } 602 break; 603 } 604 605 return (error); 606 } 607 608 static int 609 tls_server(const char *lstaddr, void **ctxp) 610 { 611 struct proto_conn *tcp; 612 struct tls_ctx *tlsctx; 613 char *laddr; 614 int error; 615 616 if (strncmp(lstaddr, "tls://", 6) != 0) 617 return (-1); 618 619 tlsctx = malloc(sizeof(*tlsctx)); 620 if (tlsctx == NULL) { 621 pjdlog_warning("Unable to allocate memory."); 622 return (ENOMEM); 623 } 624 625 laddr = strdup(lstaddr); 626 if (laddr == NULL) { 627 free(tlsctx); 628 pjdlog_warning("Unable to allocate memory."); 629 return (ENOMEM); 630 } 631 bcopy("tcp://", laddr, 6); 632 633 if (proto_server(laddr, &tcp) == -1) { 634 error = errno; 635 free(tlsctx); 636 free(laddr); 637 return (error); 638 } 639 free(laddr); 640 641 tlsctx->tls_sock = NULL; 642 tlsctx->tls_tcp = tcp; 643 tlsctx->tls_side = TLS_SIDE_SERVER_LISTEN; 644 tlsctx->tls_wait_called = true; 645 tlsctx->tls_magic = TLS_CTX_MAGIC; 646 *ctxp = tlsctx; 647 648 return (0); 649 } 650 651 static void 652 tls_exec_server(const char *user, int startfd, const char *privkey, 653 const char *cert, int debuglevel) 654 { 655 SSL_CTX *sslctx; 656 SSL *ssl; 657 int sockfd, tcpfd, ret; 658 659 pjdlog_debug_set(debuglevel); 660 pjdlog_prefix_set("[TLS sandbox] (server) "); 661 #ifdef HAVE_SETPROCTITLE 662 setproctitle("[TLS sandbox] (server) "); 663 #endif 664 665 sockfd = startfd; 666 tcpfd = startfd + 1; 667 668 #if OPENSSL_VERSION_NUMBER < 0x10100000L 669 SSL_load_error_strings(); 670 SSL_library_init(); 671 #endif 672 673 sslctx = SSL_CTX_new(TLS_server_method()); 674 if (sslctx == NULL) 675 pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed."); 676 677 SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); 678 679 ssl = SSL_new(sslctx); 680 if (ssl == NULL) 681 pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed."); 682 683 if (SSL_use_RSAPrivateKey_file(ssl, privkey, SSL_FILETYPE_PEM) != 1) { 684 ssl_log_errors(); 685 pjdlog_exitx(EX_CONFIG, 686 "SSL_use_RSAPrivateKey_file(%s) failed.", privkey); 687 } 688 689 if (SSL_use_certificate_file(ssl, cert, SSL_FILETYPE_PEM) != 1) { 690 ssl_log_errors(); 691 pjdlog_exitx(EX_CONFIG, "SSL_use_certificate_file(%s) failed.", 692 cert); 693 } 694 695 if (sandbox(user, true, "proto_tls server") != 0) 696 pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS server."); 697 pjdlog_debug(1, "Privileges successfully dropped."); 698 699 nonblock(sockfd); 700 nonblock(tcpfd); 701 702 if (SSL_set_fd(ssl, tcpfd) != 1) 703 pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed."); 704 705 ret = SSL_accept(ssl); 706 ssl_check_error(ssl, ret); 707 708 tls_loop(sockfd, ssl); 709 } 710 711 static void 712 tls_call_exec_server(struct proto_conn *sock, struct proto_conn *tcp) 713 { 714 int startfd, sockfd, tcpfd, safefd; 715 char *startfdstr, *debugstr; 716 717 if (pjdlog_mode_get() == PJDLOG_MODE_STD) 718 startfd = 3; 719 else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */ 720 startfd = 0; 721 722 /* Declare that we are receiver. */ 723 proto_send(sock, NULL, 0); 724 725 sockfd = proto_descriptor(sock); 726 tcpfd = proto_descriptor(tcp); 727 728 safefd = MAX(sockfd, tcpfd); 729 safefd = MAX(safefd, startfd); 730 safefd++; 731 732 /* Move sockfd and tcpfd to safe numbers first. */ 733 if (dup2(sockfd, safefd) == -1) 734 pjdlog_exit(EX_OSERR, "dup2() failed"); 735 proto_close(sock); 736 sockfd = safefd; 737 if (dup2(tcpfd, safefd + 1) == -1) 738 pjdlog_exit(EX_OSERR, "dup2() failed"); 739 proto_close(tcp); 740 tcpfd = safefd + 1; 741 742 /* Move socketpair descriptor to descriptor number startfd. */ 743 if (dup2(sockfd, startfd) == -1) 744 pjdlog_exit(EX_OSERR, "dup2() failed"); 745 (void)close(sockfd); 746 /* Move tcp descriptor to descriptor number startfd + 1. */ 747 if (dup2(tcpfd, startfd + 1) == -1) 748 pjdlog_exit(EX_OSERR, "dup2() failed"); 749 (void)close(tcpfd); 750 751 closefrom(startfd + 2); 752 753 /* 754 * Even if FD_CLOEXEC was set on descriptors before dup2(), it should 755 * have been cleared on dup2(), but better be safe than sorry. 756 */ 757 if (fcntl(startfd, F_SETFD, 0) == -1) 758 pjdlog_exit(EX_OSERR, "fcntl() failed"); 759 if (fcntl(startfd + 1, F_SETFD, 0) == -1) 760 pjdlog_exit(EX_OSERR, "fcntl() failed"); 761 762 if (asprintf(&startfdstr, "%d", startfd) == -1) 763 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed"); 764 if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1) 765 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed"); 766 767 execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls", 768 proto_get("user"), "server", startfdstr, proto_get("tls:keyfile"), 769 proto_get("tls:certfile"), debugstr, NULL); 770 pjdlog_exit(EX_SOFTWARE, "execl() failed"); 771 } 772 773 static int 774 tls_accept(void *ctx, void **newctxp) 775 { 776 struct tls_ctx *tlsctx = ctx; 777 struct tls_ctx *newtlsctx; 778 struct proto_conn *sock, *tcp; 779 pid_t pid; 780 int error; 781 782 PJDLOG_ASSERT(tlsctx != NULL); 783 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); 784 PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_SERVER_LISTEN); 785 786 if (proto_connect(NULL, "socketpair://", -1, &sock) == -1) 787 return (errno); 788 789 /* Accept TCP connection. */ 790 if (proto_accept(tlsctx->tls_tcp, &tcp) == -1) { 791 error = errno; 792 proto_close(sock); 793 return (error); 794 } 795 796 pid = fork(); 797 switch (pid) { 798 case -1: 799 /* Failure. */ 800 error = errno; 801 proto_close(sock); 802 return (error); 803 case 0: 804 /* Child. */ 805 pjdlog_prefix_set("[TLS sandbox] (server) "); 806 #ifdef HAVE_SETPROCTITLE 807 setproctitle("[TLS sandbox] (server) "); 808 #endif 809 /* Close listen socket. */ 810 proto_close(tlsctx->tls_tcp); 811 tls_call_exec_server(sock, tcp); 812 /* NOTREACHED */ 813 PJDLOG_ABORT("Unreachable."); 814 default: 815 /* Parent. */ 816 newtlsctx = calloc(1, sizeof(*tlsctx)); 817 if (newtlsctx == NULL) { 818 error = errno; 819 proto_close(sock); 820 proto_close(tcp); 821 (void)kill(pid, SIGKILL); 822 return (error); 823 } 824 proto_local_address(tcp, newtlsctx->tls_laddr, 825 sizeof(newtlsctx->tls_laddr)); 826 PJDLOG_ASSERT(strncmp(newtlsctx->tls_laddr, "tcp://", 6) == 0); 827 bcopy("tls://", newtlsctx->tls_laddr, 6); 828 *strrchr(newtlsctx->tls_laddr, ':') = '\0'; 829 proto_remote_address(tcp, newtlsctx->tls_raddr, 830 sizeof(newtlsctx->tls_raddr)); 831 PJDLOG_ASSERT(strncmp(newtlsctx->tls_raddr, "tcp://", 6) == 0); 832 bcopy("tls://", newtlsctx->tls_raddr, 6); 833 *strrchr(newtlsctx->tls_raddr, ':') = '\0'; 834 proto_close(tcp); 835 proto_recv(sock, NULL, 0); 836 newtlsctx->tls_sock = sock; 837 newtlsctx->tls_tcp = NULL; 838 newtlsctx->tls_wait_called = true; 839 newtlsctx->tls_side = TLS_SIDE_SERVER_WORK; 840 newtlsctx->tls_magic = TLS_CTX_MAGIC; 841 *newctxp = newtlsctx; 842 return (0); 843 } 844 } 845 846 static int 847 tls_wrap(int fd, bool client, void **ctxp) 848 { 849 struct tls_ctx *tlsctx; 850 struct proto_conn *sock; 851 int error; 852 853 tlsctx = calloc(1, sizeof(*tlsctx)); 854 if (tlsctx == NULL) 855 return (errno); 856 857 if (proto_wrap("socketpair", client, fd, &sock) == -1) { 858 error = errno; 859 free(tlsctx); 860 return (error); 861 } 862 863 tlsctx->tls_sock = sock; 864 tlsctx->tls_tcp = NULL; 865 tlsctx->tls_wait_called = (client ? false : true); 866 tlsctx->tls_side = (client ? TLS_SIDE_CLIENT : TLS_SIDE_SERVER_WORK); 867 tlsctx->tls_magic = TLS_CTX_MAGIC; 868 *ctxp = tlsctx; 869 870 return (0); 871 } 872 873 static int 874 tls_send(void *ctx, const unsigned char *data, size_t size, int fd) 875 { 876 struct tls_ctx *tlsctx = ctx; 877 878 PJDLOG_ASSERT(tlsctx != NULL); 879 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); 880 PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT || 881 tlsctx->tls_side == TLS_SIDE_SERVER_WORK); 882 PJDLOG_ASSERT(tlsctx->tls_sock != NULL); 883 PJDLOG_ASSERT(tlsctx->tls_wait_called); 884 PJDLOG_ASSERT(fd == -1); 885 886 if (proto_send(tlsctx->tls_sock, data, size) == -1) 887 return (errno); 888 889 return (0); 890 } 891 892 static int 893 tls_recv(void *ctx, unsigned char *data, size_t size, int *fdp) 894 { 895 struct tls_ctx *tlsctx = ctx; 896 897 PJDLOG_ASSERT(tlsctx != NULL); 898 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); 899 PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT || 900 tlsctx->tls_side == TLS_SIDE_SERVER_WORK); 901 PJDLOG_ASSERT(tlsctx->tls_sock != NULL); 902 PJDLOG_ASSERT(tlsctx->tls_wait_called); 903 PJDLOG_ASSERT(fdp == NULL); 904 905 if (proto_recv(tlsctx->tls_sock, data, size) == -1) 906 return (errno); 907 908 return (0); 909 } 910 911 static int 912 tls_descriptor(const void *ctx) 913 { 914 const struct tls_ctx *tlsctx = ctx; 915 916 PJDLOG_ASSERT(tlsctx != NULL); 917 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); 918 919 switch (tlsctx->tls_side) { 920 case TLS_SIDE_CLIENT: 921 case TLS_SIDE_SERVER_WORK: 922 PJDLOG_ASSERT(tlsctx->tls_sock != NULL); 923 924 return (proto_descriptor(tlsctx->tls_sock)); 925 case TLS_SIDE_SERVER_LISTEN: 926 PJDLOG_ASSERT(tlsctx->tls_tcp != NULL); 927 928 return (proto_descriptor(tlsctx->tls_tcp)); 929 default: 930 PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side); 931 } 932 } 933 934 static bool 935 tcp_address_match(const void *ctx, const char *addr) 936 { 937 const struct tls_ctx *tlsctx = ctx; 938 939 PJDLOG_ASSERT(tlsctx != NULL); 940 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); 941 942 return (strcmp(tlsctx->tls_raddr, addr) == 0); 943 } 944 945 static void 946 tls_local_address(const void *ctx, char *addr, size_t size) 947 { 948 const struct tls_ctx *tlsctx = ctx; 949 950 PJDLOG_ASSERT(tlsctx != NULL); 951 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); 952 PJDLOG_ASSERT(tlsctx->tls_wait_called); 953 954 switch (tlsctx->tls_side) { 955 case TLS_SIDE_CLIENT: 956 PJDLOG_ASSERT(tlsctx->tls_sock != NULL); 957 958 PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size); 959 break; 960 case TLS_SIDE_SERVER_WORK: 961 PJDLOG_ASSERT(tlsctx->tls_sock != NULL); 962 963 PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_laddr, size) < size); 964 break; 965 case TLS_SIDE_SERVER_LISTEN: 966 PJDLOG_ASSERT(tlsctx->tls_tcp != NULL); 967 968 proto_local_address(tlsctx->tls_tcp, addr, size); 969 PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0); 970 /* Replace tcp:// prefix with tls:// */ 971 bcopy("tls://", addr, 6); 972 break; 973 default: 974 PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side); 975 } 976 } 977 978 static void 979 tls_remote_address(const void *ctx, char *addr, size_t size) 980 { 981 const struct tls_ctx *tlsctx = ctx; 982 983 PJDLOG_ASSERT(tlsctx != NULL); 984 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); 985 PJDLOG_ASSERT(tlsctx->tls_wait_called); 986 987 switch (tlsctx->tls_side) { 988 case TLS_SIDE_CLIENT: 989 PJDLOG_ASSERT(tlsctx->tls_sock != NULL); 990 991 PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size); 992 break; 993 case TLS_SIDE_SERVER_WORK: 994 PJDLOG_ASSERT(tlsctx->tls_sock != NULL); 995 996 PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_raddr, size) < size); 997 break; 998 case TLS_SIDE_SERVER_LISTEN: 999 PJDLOG_ASSERT(tlsctx->tls_tcp != NULL); 1000 1001 proto_remote_address(tlsctx->tls_tcp, addr, size); 1002 PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0); 1003 /* Replace tcp:// prefix with tls:// */ 1004 bcopy("tls://", addr, 6); 1005 break; 1006 default: 1007 PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side); 1008 } 1009 } 1010 1011 static void 1012 tls_close(void *ctx) 1013 { 1014 struct tls_ctx *tlsctx = ctx; 1015 1016 PJDLOG_ASSERT(tlsctx != NULL); 1017 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); 1018 1019 if (tlsctx->tls_sock != NULL) { 1020 proto_close(tlsctx->tls_sock); 1021 tlsctx->tls_sock = NULL; 1022 } 1023 if (tlsctx->tls_tcp != NULL) { 1024 proto_close(tlsctx->tls_tcp); 1025 tlsctx->tls_tcp = NULL; 1026 } 1027 tlsctx->tls_side = 0; 1028 tlsctx->tls_magic = 0; 1029 free(tlsctx); 1030 } 1031 1032 static int 1033 tls_exec(int argc, char *argv[]) 1034 { 1035 1036 PJDLOG_ASSERT(argc > 3); 1037 PJDLOG_ASSERT(strcmp(argv[0], "tls") == 0); 1038 1039 pjdlog_init(atoi(argv[3]) == 0 ? PJDLOG_MODE_SYSLOG : PJDLOG_MODE_STD); 1040 1041 if (strcmp(argv[2], "client") == 0) { 1042 if (argc != 10) 1043 return (EINVAL); 1044 tls_exec_client(argv[1], atoi(argv[3]), 1045 argv[4][0] == '\0' ? NULL : argv[4], argv[5], argv[6], 1046 argv[7], atoi(argv[8]), atoi(argv[9])); 1047 } else if (strcmp(argv[2], "server") == 0) { 1048 if (argc != 7) 1049 return (EINVAL); 1050 tls_exec_server(argv[1], atoi(argv[3]), argv[4], argv[5], 1051 atoi(argv[6])); 1052 } 1053 return (EINVAL); 1054 } 1055 1056 static struct proto tls_proto = { 1057 .prt_name = "tls", 1058 .prt_connect = tls_connect, 1059 .prt_connect_wait = tls_connect_wait, 1060 .prt_server = tls_server, 1061 .prt_accept = tls_accept, 1062 .prt_wrap = tls_wrap, 1063 .prt_send = tls_send, 1064 .prt_recv = tls_recv, 1065 .prt_descriptor = tls_descriptor, 1066 .prt_address_match = tcp_address_match, 1067 .prt_local_address = tls_local_address, 1068 .prt_remote_address = tls_remote_address, 1069 .prt_close = tls_close, 1070 .prt_exec = tls_exec 1071 }; 1072 1073 static __constructor void 1074 tls_ctor(void) 1075 { 1076 1077 proto_register(&tls_proto, false); 1078 } 1079