/*- * Copyright (c) 2011 The FreeBSD Foundation * All rights reserved. * * This software was developed by Pawel Jakub Dawidek under sponsorship from * the FreeBSD Foundation. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions * are met: * 1. Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * 2. Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF * SUCH DAMAGE. */ #include #include /* MAXHOSTNAMELEN */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #ifndef HAVE_CLOSEFROM #include #endif #ifndef HAVE_STRLCPY #include #endif #include "pjdlog.h" #include "proto_impl.h" #include "sandbox.h" #include "subr.h" #define TLS_CTX_MAGIC 0x715c7 struct tls_ctx { int tls_magic; struct proto_conn *tls_sock; struct proto_conn *tls_tcp; char tls_laddr[256]; char tls_raddr[256]; int tls_side; #define TLS_SIDE_CLIENT 0 #define TLS_SIDE_SERVER_LISTEN 1 #define TLS_SIDE_SERVER_WORK 2 bool tls_wait_called; }; #define TLS_DEFAULT_TIMEOUT 30 static int tls_connect_wait(void *ctx, int timeout); static void tls_close(void *ctx); static void block(int fd) { int flags; flags = fcntl(fd, F_GETFL); if (flags == -1) pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed"); flags &= ~O_NONBLOCK; if (fcntl(fd, F_SETFL, flags) == -1) pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed"); } static void nonblock(int fd) { int flags; flags = fcntl(fd, F_GETFL); if (flags == -1) pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed"); flags |= O_NONBLOCK; if (fcntl(fd, F_SETFL, flags) == -1) pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed"); } static int wait_for_fd(int fd, int timeout) { struct timeval tv; fd_set fdset; int error, ret; error = 0; for (;;) { FD_ZERO(&fdset); FD_SET(fd, &fdset); tv.tv_sec = timeout; tv.tv_usec = 0; ret = select(fd + 1, NULL, &fdset, NULL, timeout == -1 ? NULL : &tv); if (ret == 0) { error = ETIMEDOUT; break; } else if (ret == -1) { if (errno == EINTR) continue; error = errno; break; } PJDLOG_ASSERT(ret > 0); PJDLOG_ASSERT(FD_ISSET(fd, &fdset)); break; } return (error); } static void ssl_log_errors(void) { unsigned long error; while ((error = ERR_get_error()) != 0) pjdlog_error("SSL error: %s", ERR_error_string(error, NULL)); } static int ssl_check_error(SSL *ssl, int ret) { int error; error = SSL_get_error(ssl, ret); switch (error) { case SSL_ERROR_NONE: return (0); case SSL_ERROR_WANT_READ: pjdlog_debug(2, "SSL_ERROR_WANT_READ"); return (-1); case SSL_ERROR_WANT_WRITE: pjdlog_debug(2, "SSL_ERROR_WANT_WRITE"); return (-1); case SSL_ERROR_ZERO_RETURN: pjdlog_exitx(EX_OK, "Connection closed."); case SSL_ERROR_SYSCALL: ssl_log_errors(); pjdlog_exitx(EX_TEMPFAIL, "SSL I/O error."); case SSL_ERROR_SSL: ssl_log_errors(); pjdlog_exitx(EX_TEMPFAIL, "SSL protocol error."); default: ssl_log_errors(); pjdlog_exitx(EX_TEMPFAIL, "Unknown SSL error (%d).", error); } } static void tcp_recv_ssl_send(int recvfd, SSL *sendssl) { static unsigned char buf[65536]; ssize_t tcpdone; int sendfd, ssldone; sendfd = SSL_get_fd(sendssl); PJDLOG_ASSERT(sendfd >= 0); pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd); for (;;) { tcpdone = recv(recvfd, buf, sizeof(buf), 0); pjdlog_debug(2, "%s: recv() returned %zd", __func__, tcpdone); if (tcpdone == 0) { pjdlog_debug(1, "Connection terminated."); exit(0); } else if (tcpdone == -1) { if (errno == EINTR) continue; else if (errno == EAGAIN) break; pjdlog_exit(EX_TEMPFAIL, "recv() failed"); } for (;;) { ssldone = SSL_write(sendssl, buf, (int)tcpdone); pjdlog_debug(2, "%s: send() returned %d", __func__, ssldone); if (ssl_check_error(sendssl, ssldone) == -1) { (void)wait_for_fd(sendfd, -1); continue; } PJDLOG_ASSERT(ssldone == tcpdone); break; } } pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd); } static void ssl_recv_tcp_send(SSL *recvssl, int sendfd) { static unsigned char buf[65536]; unsigned char *ptr; ssize_t tcpdone; size_t todo; int recvfd, ssldone; recvfd = SSL_get_fd(recvssl); PJDLOG_ASSERT(recvfd >= 0); pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd); for (;;) { ssldone = SSL_read(recvssl, buf, sizeof(buf)); pjdlog_debug(2, "%s: SSL_read() returned %d", __func__, ssldone); if (ssl_check_error(recvssl, ssldone) == -1) break; todo = (size_t)ssldone; ptr = buf; do { tcpdone = send(sendfd, ptr, todo, MSG_NOSIGNAL); pjdlog_debug(2, "%s: send() returned %zd", __func__, tcpdone); if (tcpdone == 0) { pjdlog_debug(1, "Connection terminated."); exit(0); } else if (tcpdone == -1) { if (errno == EINTR || errno == ENOBUFS) continue; if (errno == EAGAIN) { (void)wait_for_fd(sendfd, -1); continue; } pjdlog_exit(EX_TEMPFAIL, "send() failed"); } todo -= tcpdone; ptr += tcpdone; } while (todo > 0); } pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd); } static void tls_loop(int sockfd, SSL *tcpssl) { fd_set fds; int maxfd, tcpfd; tcpfd = SSL_get_fd(tcpssl); PJDLOG_ASSERT(tcpfd >= 0); for (;;) { FD_ZERO(&fds); FD_SET(sockfd, &fds); FD_SET(tcpfd, &fds); maxfd = MAX(sockfd, tcpfd); PJDLOG_ASSERT(maxfd + 1 <= (int)FD_SETSIZE); if (select(maxfd + 1, &fds, NULL, NULL, NULL) == -1) { if (errno == EINTR) continue; pjdlog_exit(EX_TEMPFAIL, "select() failed"); } if (FD_ISSET(sockfd, &fds)) tcp_recv_ssl_send(sockfd, tcpssl); if (FD_ISSET(tcpfd, &fds)) ssl_recv_tcp_send(tcpssl, sockfd); } } static void tls_certificate_verify(SSL *ssl, const char *fingerprint) { unsigned char md[EVP_MAX_MD_SIZE]; char mdstr[sizeof("SHA256=") - 1 + EVP_MAX_MD_SIZE * 3]; char *mdstrp; unsigned int i, mdsize; X509 *cert; if (fingerprint[0] == '\0') { pjdlog_debug(1, "No fingerprint verification requested."); return; } cert = SSL_get_peer_certificate(ssl); if (cert == NULL) pjdlog_exitx(EX_TEMPFAIL, "No peer certificate received."); if (X509_digest(cert, EVP_sha256(), md, &mdsize) != 1) pjdlog_exitx(EX_TEMPFAIL, "X509_digest() failed."); PJDLOG_ASSERT(mdsize <= EVP_MAX_MD_SIZE); X509_free(cert); (void)strlcpy(mdstr, "SHA256=", sizeof(mdstr)); mdstrp = mdstr + strlen(mdstr); for (i = 0; i < mdsize; i++) { PJDLOG_VERIFY(mdstrp + 3 <= mdstr + sizeof(mdstr)); (void)sprintf(mdstrp, "%02hhX:", md[i]); mdstrp += 3; } /* Clear last colon. */ mdstrp[-1] = '\0'; if (strcasecmp(mdstr, fingerprint) != 0) { pjdlog_exitx(EX_NOPERM, "Finger print doesn't match. Received \"%s\", expected \"%s\"", mdstr, fingerprint); } } static void tls_exec_client(const char *user, int startfd, const char *srcaddr, const char *dstaddr, const char *fingerprint, const char *defport, int timeout, int debuglevel) { struct proto_conn *tcp; char *saddr, *daddr; SSL_CTX *sslctx; SSL *ssl; long ret; int sockfd, tcpfd; uint8_t connected; pjdlog_debug_set(debuglevel); pjdlog_prefix_set("[TLS sandbox] (client) "); #ifdef HAVE_SETPROCTITLE setproctitle("[TLS sandbox] (client) "); #endif proto_set("tcp:port", defport); sockfd = startfd; /* Change tls:// to tcp://. */ if (srcaddr == NULL) { saddr = NULL; } else { saddr = strdup(srcaddr); if (saddr == NULL) pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory."); bcopy("tcp://", saddr, 6); } daddr = strdup(dstaddr); if (daddr == NULL) pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory."); bcopy("tcp://", daddr, 6); /* Establish TCP connection. */ if (proto_connect(saddr, daddr, timeout, &tcp) == -1) exit(EX_TEMPFAIL); #if OPENSSL_VERSION_NUMBER < 0x10100000L SSL_load_error_strings(); SSL_library_init(); #endif /* * TODO: On FreeBSD we could move this below sandbox() once libc and * libcrypto use sysctl kern.arandom to obtain random data * instead of /dev/urandom and friends. */ sslctx = SSL_CTX_new(TLS_client_method()); if (sslctx == NULL) pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed."); if (sandbox(user, true, "proto_tls client: %s", dstaddr) != 0) pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS client."); pjdlog_debug(1, "Privileges successfully dropped."); SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); /* Load CA certs. */ /* TODO */ //SSL_CTX_load_verify_locations(sslctx, cacerts_file, NULL); ssl = SSL_new(sslctx); if (ssl == NULL) pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed."); tcpfd = proto_descriptor(tcp); block(tcpfd); if (SSL_set_fd(ssl, tcpfd) != 1) pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed."); ret = SSL_connect(ssl); ssl_check_error(ssl, (int)ret); nonblock(sockfd); nonblock(tcpfd); tls_certificate_verify(ssl, fingerprint); /* * The following byte is sent to make proto_connect_wait() work. */ connected = 1; for (;;) { switch (send(sockfd, &connected, sizeof(connected), 0)) { case -1: if (errno == EINTR || errno == ENOBUFS) continue; if (errno == EAGAIN) { (void)wait_for_fd(sockfd, -1); continue; } pjdlog_exit(EX_TEMPFAIL, "send() failed"); case 0: pjdlog_debug(1, "Connection terminated."); exit(0); case 1: break; } break; } tls_loop(sockfd, ssl); } static void tls_call_exec_client(struct proto_conn *sock, const char *srcaddr, const char *dstaddr, int timeout) { char *timeoutstr, *startfdstr, *debugstr; int startfd; /* Declare that we are receiver. */ proto_recv(sock, NULL, 0); if (pjdlog_mode_get() == PJDLOG_MODE_STD) startfd = 3; else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */ startfd = 0; if (proto_descriptor(sock) != startfd) { /* Move socketpair descriptor to descriptor number startfd. */ if (dup2(proto_descriptor(sock), startfd) == -1) pjdlog_exit(EX_OSERR, "dup2() failed"); proto_close(sock); } else { /* * The FD_CLOEXEC is cleared by dup2(2), so when we do not * call it, we have to clear it by hand in case it is set. */ if (fcntl(startfd, F_SETFD, 0) == -1) pjdlog_exit(EX_OSERR, "fcntl() failed"); } closefrom(startfd + 1); if (asprintf(&startfdstr, "%d", startfd) == -1) pjdlog_exit(EX_TEMPFAIL, "asprintf() failed"); if (timeout == -1) timeout = TLS_DEFAULT_TIMEOUT; if (asprintf(&timeoutstr, "%d", timeout) == -1) pjdlog_exit(EX_TEMPFAIL, "asprintf() failed"); if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1) pjdlog_exit(EX_TEMPFAIL, "asprintf() failed"); execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls", proto_get("user"), "client", startfdstr, srcaddr == NULL ? "" : srcaddr, dstaddr, proto_get("tls:fingerprint"), proto_get("tcp:port"), timeoutstr, debugstr, NULL); pjdlog_exit(EX_SOFTWARE, "execl() failed"); } static int tls_connect(const char *srcaddr, const char *dstaddr, int timeout, void **ctxp) { struct tls_ctx *tlsctx; struct proto_conn *sock; pid_t pid; int error; PJDLOG_ASSERT(srcaddr == NULL || srcaddr[0] != '\0'); PJDLOG_ASSERT(dstaddr != NULL); PJDLOG_ASSERT(timeout >= -1); PJDLOG_ASSERT(ctxp != NULL); if (strncmp(dstaddr, "tls://", 6) != 0) return (-1); if (srcaddr != NULL && strncmp(srcaddr, "tls://", 6) != 0) return (-1); if (proto_connect(NULL, "socketpair://", -1, &sock) == -1) return (errno); #if 0 /* * We use rfork() with the following flags to disable SIGCHLD * delivery upon the sandbox process exit. */ pid = rfork(RFFDG | RFPROC | RFTSIGZMB | RFTSIGFLAGS(0)); #else /* * We don't use rfork() to be able to log information about sandbox * process exiting. */ pid = fork(); #endif switch (pid) { case -1: /* Failure. */ error = errno; proto_close(sock); return (error); case 0: /* Child. */ pjdlog_prefix_set("[TLS sandbox] (client) "); #ifdef HAVE_SETPROCTITLE setproctitle("[TLS sandbox] (client) "); #endif tls_call_exec_client(sock, srcaddr, dstaddr, timeout); /* NOTREACHED */ default: /* Parent. */ tlsctx = calloc(1, sizeof(*tlsctx)); if (tlsctx == NULL) { error = errno; proto_close(sock); (void)kill(pid, SIGKILL); return (error); } proto_send(sock, NULL, 0); tlsctx->tls_sock = sock; tlsctx->tls_tcp = NULL; tlsctx->tls_side = TLS_SIDE_CLIENT; tlsctx->tls_wait_called = false; tlsctx->tls_magic = TLS_CTX_MAGIC; if (timeout >= 0) { error = tls_connect_wait(tlsctx, timeout); if (error != 0) { (void)kill(pid, SIGKILL); tls_close(tlsctx); return (error); } } *ctxp = tlsctx; return (0); } } static int tls_connect_wait(void *ctx, int timeout) { struct tls_ctx *tlsctx = ctx; int error, sockfd; uint8_t connected; PJDLOG_ASSERT(tlsctx != NULL); PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT); PJDLOG_ASSERT(tlsctx->tls_sock != NULL); PJDLOG_ASSERT(!tlsctx->tls_wait_called); PJDLOG_ASSERT(timeout >= 0); sockfd = proto_descriptor(tlsctx->tls_sock); error = wait_for_fd(sockfd, timeout); if (error != 0) return (error); for (;;) { switch (recv(sockfd, &connected, sizeof(connected), MSG_WAITALL)) { case -1: if (errno == EINTR || errno == ENOBUFS) continue; error = errno; break; case 0: pjdlog_debug(1, "Connection terminated."); error = ENOTCONN; break; case 1: tlsctx->tls_wait_called = true; break; } break; } return (error); } static int tls_server(const char *lstaddr, void **ctxp) { struct proto_conn *tcp; struct tls_ctx *tlsctx; char *laddr; int error; if (strncmp(lstaddr, "tls://", 6) != 0) return (-1); tlsctx = malloc(sizeof(*tlsctx)); if (tlsctx == NULL) { pjdlog_warning("Unable to allocate memory."); return (ENOMEM); } laddr = strdup(lstaddr); if (laddr == NULL) { free(tlsctx); pjdlog_warning("Unable to allocate memory."); return (ENOMEM); } bcopy("tcp://", laddr, 6); if (proto_server(laddr, &tcp) == -1) { error = errno; free(tlsctx); free(laddr); return (error); } free(laddr); tlsctx->tls_sock = NULL; tlsctx->tls_tcp = tcp; tlsctx->tls_side = TLS_SIDE_SERVER_LISTEN; tlsctx->tls_wait_called = true; tlsctx->tls_magic = TLS_CTX_MAGIC; *ctxp = tlsctx; return (0); } static void tls_exec_server(const char *user, int startfd, const char *privkey, const char *cert, int debuglevel) { SSL_CTX *sslctx; SSL *ssl; int sockfd, tcpfd, ret; pjdlog_debug_set(debuglevel); pjdlog_prefix_set("[TLS sandbox] (server) "); #ifdef HAVE_SETPROCTITLE setproctitle("[TLS sandbox] (server) "); #endif sockfd = startfd; tcpfd = startfd + 1; #if OPENSSL_VERSION_NUMBER < 0x10100000L SSL_load_error_strings(); SSL_library_init(); #endif sslctx = SSL_CTX_new(TLS_server_method()); if (sslctx == NULL) pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed."); SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); ssl = SSL_new(sslctx); if (ssl == NULL) pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed."); if (SSL_use_RSAPrivateKey_file(ssl, privkey, SSL_FILETYPE_PEM) != 1) { ssl_log_errors(); pjdlog_exitx(EX_CONFIG, "SSL_use_RSAPrivateKey_file(%s) failed.", privkey); } if (SSL_use_certificate_file(ssl, cert, SSL_FILETYPE_PEM) != 1) { ssl_log_errors(); pjdlog_exitx(EX_CONFIG, "SSL_use_certificate_file(%s) failed.", cert); } if (sandbox(user, true, "proto_tls server") != 0) pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS server."); pjdlog_debug(1, "Privileges successfully dropped."); nonblock(sockfd); nonblock(tcpfd); if (SSL_set_fd(ssl, tcpfd) != 1) pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed."); ret = SSL_accept(ssl); ssl_check_error(ssl, ret); tls_loop(sockfd, ssl); } static void tls_call_exec_server(struct proto_conn *sock, struct proto_conn *tcp) { int startfd, sockfd, tcpfd, safefd; char *startfdstr, *debugstr; if (pjdlog_mode_get() == PJDLOG_MODE_STD) startfd = 3; else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */ startfd = 0; /* Declare that we are receiver. */ proto_send(sock, NULL, 0); sockfd = proto_descriptor(sock); tcpfd = proto_descriptor(tcp); safefd = MAX(sockfd, tcpfd); safefd = MAX(safefd, startfd); safefd++; /* Move sockfd and tcpfd to safe numbers first. */ if (dup2(sockfd, safefd) == -1) pjdlog_exit(EX_OSERR, "dup2() failed"); proto_close(sock); sockfd = safefd; if (dup2(tcpfd, safefd + 1) == -1) pjdlog_exit(EX_OSERR, "dup2() failed"); proto_close(tcp); tcpfd = safefd + 1; /* Move socketpair descriptor to descriptor number startfd. */ if (dup2(sockfd, startfd) == -1) pjdlog_exit(EX_OSERR, "dup2() failed"); (void)close(sockfd); /* Move tcp descriptor to descriptor number startfd + 1. */ if (dup2(tcpfd, startfd + 1) == -1) pjdlog_exit(EX_OSERR, "dup2() failed"); (void)close(tcpfd); closefrom(startfd + 2); /* * Even if FD_CLOEXEC was set on descriptors before dup2(), it should * have been cleared on dup2(), but better be safe than sorry. */ if (fcntl(startfd, F_SETFD, 0) == -1) pjdlog_exit(EX_OSERR, "fcntl() failed"); if (fcntl(startfd + 1, F_SETFD, 0) == -1) pjdlog_exit(EX_OSERR, "fcntl() failed"); if (asprintf(&startfdstr, "%d", startfd) == -1) pjdlog_exit(EX_TEMPFAIL, "asprintf() failed"); if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1) pjdlog_exit(EX_TEMPFAIL, "asprintf() failed"); execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls", proto_get("user"), "server", startfdstr, proto_get("tls:keyfile"), proto_get("tls:certfile"), debugstr, NULL); pjdlog_exit(EX_SOFTWARE, "execl() failed"); } static int tls_accept(void *ctx, void **newctxp) { struct tls_ctx *tlsctx = ctx; struct tls_ctx *newtlsctx; struct proto_conn *sock, *tcp; pid_t pid; int error; PJDLOG_ASSERT(tlsctx != NULL); PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_SERVER_LISTEN); if (proto_connect(NULL, "socketpair://", -1, &sock) == -1) return (errno); /* Accept TCP connection. */ if (proto_accept(tlsctx->tls_tcp, &tcp) == -1) { error = errno; proto_close(sock); return (error); } pid = fork(); switch (pid) { case -1: /* Failure. */ error = errno; proto_close(sock); return (error); case 0: /* Child. */ pjdlog_prefix_set("[TLS sandbox] (server) "); #ifdef HAVE_SETPROCTITLE setproctitle("[TLS sandbox] (server) "); #endif /* Close listen socket. */ proto_close(tlsctx->tls_tcp); tls_call_exec_server(sock, tcp); /* NOTREACHED */ PJDLOG_ABORT("Unreachable."); default: /* Parent. */ newtlsctx = calloc(1, sizeof(*tlsctx)); if (newtlsctx == NULL) { error = errno; proto_close(sock); proto_close(tcp); (void)kill(pid, SIGKILL); return (error); } proto_local_address(tcp, newtlsctx->tls_laddr, sizeof(newtlsctx->tls_laddr)); PJDLOG_ASSERT(strncmp(newtlsctx->tls_laddr, "tcp://", 6) == 0); bcopy("tls://", newtlsctx->tls_laddr, 6); *strrchr(newtlsctx->tls_laddr, ':') = '\0'; proto_remote_address(tcp, newtlsctx->tls_raddr, sizeof(newtlsctx->tls_raddr)); PJDLOG_ASSERT(strncmp(newtlsctx->tls_raddr, "tcp://", 6) == 0); bcopy("tls://", newtlsctx->tls_raddr, 6); *strrchr(newtlsctx->tls_raddr, ':') = '\0'; proto_close(tcp); proto_recv(sock, NULL, 0); newtlsctx->tls_sock = sock; newtlsctx->tls_tcp = NULL; newtlsctx->tls_wait_called = true; newtlsctx->tls_side = TLS_SIDE_SERVER_WORK; newtlsctx->tls_magic = TLS_CTX_MAGIC; *newctxp = newtlsctx; return (0); } } static int tls_wrap(int fd, bool client, void **ctxp) { struct tls_ctx *tlsctx; struct proto_conn *sock; int error; tlsctx = calloc(1, sizeof(*tlsctx)); if (tlsctx == NULL) return (errno); if (proto_wrap("socketpair", client, fd, &sock) == -1) { error = errno; free(tlsctx); return (error); } tlsctx->tls_sock = sock; tlsctx->tls_tcp = NULL; tlsctx->tls_wait_called = (client ? false : true); tlsctx->tls_side = (client ? TLS_SIDE_CLIENT : TLS_SIDE_SERVER_WORK); tlsctx->tls_magic = TLS_CTX_MAGIC; *ctxp = tlsctx; return (0); } static int tls_send(void *ctx, const unsigned char *data, size_t size, int fd) { struct tls_ctx *tlsctx = ctx; PJDLOG_ASSERT(tlsctx != NULL); PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT || tlsctx->tls_side == TLS_SIDE_SERVER_WORK); PJDLOG_ASSERT(tlsctx->tls_sock != NULL); PJDLOG_ASSERT(tlsctx->tls_wait_called); PJDLOG_ASSERT(fd == -1); if (proto_send(tlsctx->tls_sock, data, size) == -1) return (errno); return (0); } static int tls_recv(void *ctx, unsigned char *data, size_t size, int *fdp) { struct tls_ctx *tlsctx = ctx; PJDLOG_ASSERT(tlsctx != NULL); PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT || tlsctx->tls_side == TLS_SIDE_SERVER_WORK); PJDLOG_ASSERT(tlsctx->tls_sock != NULL); PJDLOG_ASSERT(tlsctx->tls_wait_called); PJDLOG_ASSERT(fdp == NULL); if (proto_recv(tlsctx->tls_sock, data, size) == -1) return (errno); return (0); } static int tls_descriptor(const void *ctx) { const struct tls_ctx *tlsctx = ctx; PJDLOG_ASSERT(tlsctx != NULL); PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); switch (tlsctx->tls_side) { case TLS_SIDE_CLIENT: case TLS_SIDE_SERVER_WORK: PJDLOG_ASSERT(tlsctx->tls_sock != NULL); return (proto_descriptor(tlsctx->tls_sock)); case TLS_SIDE_SERVER_LISTEN: PJDLOG_ASSERT(tlsctx->tls_tcp != NULL); return (proto_descriptor(tlsctx->tls_tcp)); default: PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side); } } static bool tcp_address_match(const void *ctx, const char *addr) { const struct tls_ctx *tlsctx = ctx; PJDLOG_ASSERT(tlsctx != NULL); PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); return (strcmp(tlsctx->tls_raddr, addr) == 0); } static void tls_local_address(const void *ctx, char *addr, size_t size) { const struct tls_ctx *tlsctx = ctx; PJDLOG_ASSERT(tlsctx != NULL); PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); PJDLOG_ASSERT(tlsctx->tls_wait_called); switch (tlsctx->tls_side) { case TLS_SIDE_CLIENT: PJDLOG_ASSERT(tlsctx->tls_sock != NULL); PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size); break; case TLS_SIDE_SERVER_WORK: PJDLOG_ASSERT(tlsctx->tls_sock != NULL); PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_laddr, size) < size); break; case TLS_SIDE_SERVER_LISTEN: PJDLOG_ASSERT(tlsctx->tls_tcp != NULL); proto_local_address(tlsctx->tls_tcp, addr, size); PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0); /* Replace tcp:// prefix with tls:// */ bcopy("tls://", addr, 6); break; default: PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side); } } static void tls_remote_address(const void *ctx, char *addr, size_t size) { const struct tls_ctx *tlsctx = ctx; PJDLOG_ASSERT(tlsctx != NULL); PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); PJDLOG_ASSERT(tlsctx->tls_wait_called); switch (tlsctx->tls_side) { case TLS_SIDE_CLIENT: PJDLOG_ASSERT(tlsctx->tls_sock != NULL); PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size); break; case TLS_SIDE_SERVER_WORK: PJDLOG_ASSERT(tlsctx->tls_sock != NULL); PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_raddr, size) < size); break; case TLS_SIDE_SERVER_LISTEN: PJDLOG_ASSERT(tlsctx->tls_tcp != NULL); proto_remote_address(tlsctx->tls_tcp, addr, size); PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0); /* Replace tcp:// prefix with tls:// */ bcopy("tls://", addr, 6); break; default: PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side); } } static void tls_close(void *ctx) { struct tls_ctx *tlsctx = ctx; PJDLOG_ASSERT(tlsctx != NULL); PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC); if (tlsctx->tls_sock != NULL) { proto_close(tlsctx->tls_sock); tlsctx->tls_sock = NULL; } if (tlsctx->tls_tcp != NULL) { proto_close(tlsctx->tls_tcp); tlsctx->tls_tcp = NULL; } tlsctx->tls_side = 0; tlsctx->tls_magic = 0; free(tlsctx); } static int tls_exec(int argc, char *argv[]) { PJDLOG_ASSERT(argc > 3); PJDLOG_ASSERT(strcmp(argv[0], "tls") == 0); pjdlog_init(atoi(argv[3]) == 0 ? PJDLOG_MODE_SYSLOG : PJDLOG_MODE_STD); if (strcmp(argv[2], "client") == 0) { if (argc != 10) return (EINVAL); tls_exec_client(argv[1], atoi(argv[3]), argv[4][0] == '\0' ? NULL : argv[4], argv[5], argv[6], argv[7], atoi(argv[8]), atoi(argv[9])); } else if (strcmp(argv[2], "server") == 0) { if (argc != 7) return (EINVAL); tls_exec_server(argv[1], atoi(argv[3]), argv[4], argv[5], atoi(argv[6])); } return (EINVAL); } static struct proto tls_proto = { .prt_name = "tls", .prt_connect = tls_connect, .prt_connect_wait = tls_connect_wait, .prt_server = tls_server, .prt_accept = tls_accept, .prt_wrap = tls_wrap, .prt_send = tls_send, .prt_recv = tls_recv, .prt_descriptor = tls_descriptor, .prt_address_match = tcp_address_match, .prt_local_address = tls_local_address, .prt_remote_address = tls_remote_address, .prt_close = tls_close, .prt_exec = tls_exec }; static __constructor void tls_ctor(void) { proto_register(&tls_proto, false); }