1 /*-
2 * Copyright (c) 2011 The FreeBSD Foundation
3 * All rights reserved.
4 *
5 * This software was developed by Pawel Jakub Dawidek under sponsorship from
6 * the FreeBSD Foundation.
7 *
8 * Redistribution and use in source and binary forms, with or without
9 * modification, are permitted provided that the following conditions
10 * are met:
11 * 1. Redistributions of source code must retain the above copyright
12 * notice, this list of conditions and the following disclaimer.
13 * 2. Redistributions in binary form must reproduce the above copyright
14 * notice, this list of conditions and the following disclaimer in the
15 * documentation and/or other materials provided with the distribution.
16 *
17 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27 * SUCH DAMAGE.
28 */
29
30 #include <config/config.h>
31
32 #include <sys/param.h> /* MAXHOSTNAMELEN */
33 #include <sys/socket.h>
34
35 #include <arpa/inet.h>
36
37 #include <netinet/in.h>
38 #include <netinet/tcp.h>
39
40 #include <errno.h>
41 #include <fcntl.h>
42 #include <netdb.h>
43 #include <signal.h>
44 #include <stdbool.h>
45 #include <stdint.h>
46 #include <stdio.h>
47 #include <string.h>
48 #include <unistd.h>
49
50 #include <openssl/err.h>
51 #include <openssl/ssl.h>
52
53 #include <compat/compat.h>
54 #ifndef HAVE_CLOSEFROM
55 #include <compat/closefrom.h>
56 #endif
57 #ifndef HAVE_STRLCPY
58 #include <compat/strlcpy.h>
59 #endif
60
61 #include "pjdlog.h"
62 #include "proto_impl.h"
63 #include "sandbox.h"
64 #include "subr.h"
65
/* Value kept in tls_magic while a context is valid; cleared by tls_close(). */
#define TLS_CTX_MAGIC 0x715c7
/* State of one tls:// endpoint as seen by the proto layer. */
struct tls_ctx {
	int tls_magic;			/* TLS_CTX_MAGIC when valid. */
	struct proto_conn *tls_sock;	/* Socketpair to the TLS sandbox (client / server work side). */
	struct proto_conn *tls_tcp;	/* Listening TCP socket (server listen side only). */
	char tls_laddr[256];		/* "tls://..." local address, filled in by tls_accept(). */
	char tls_raddr[256];		/* "tls://..." remote address, filled in by tls_accept(). */
	int tls_side;			/* One of the TLS_SIDE_* values below. */
#define TLS_SIDE_CLIENT 0
#define TLS_SIDE_SERVER_LISTEN 1
#define TLS_SIDE_SERVER_WORK 2
	/*
	 * True once the endpoint may be used for send/recv: immediately for
	 * server sides, after tls_connect_wait() succeeds for clients.
	 */
	bool tls_wait_called;
};

/* Timeout (seconds) passed to the client sandbox when the caller gave -1. */
#define TLS_DEFAULT_TIMEOUT 30

static int tls_connect_wait(void *ctx, int timeout);
static void tls_close(void *ctx);
84
85 static void
block(int fd)86 block(int fd)
87 {
88 int flags;
89
90 flags = fcntl(fd, F_GETFL);
91 if (flags == -1)
92 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed");
93 flags &= ~O_NONBLOCK;
94 if (fcntl(fd, F_SETFL, flags) == -1)
95 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed");
96 }
97
98 static void
nonblock(int fd)99 nonblock(int fd)
100 {
101 int flags;
102
103 flags = fcntl(fd, F_GETFL);
104 if (flags == -1)
105 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed");
106 flags |= O_NONBLOCK;
107 if (fcntl(fd, F_SETFL, flags) == -1)
108 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed");
109 }
110
111 static int
wait_for_fd(int fd,int timeout)112 wait_for_fd(int fd, int timeout)
113 {
114 struct timeval tv;
115 fd_set fdset;
116 int error, ret;
117
118 error = 0;
119
120 for (;;) {
121 FD_ZERO(&fdset);
122 FD_SET(fd, &fdset);
123
124 tv.tv_sec = timeout;
125 tv.tv_usec = 0;
126
127 ret = select(fd + 1, NULL, &fdset, NULL,
128 timeout == -1 ? NULL : &tv);
129 if (ret == 0) {
130 error = ETIMEDOUT;
131 break;
132 } else if (ret == -1) {
133 if (errno == EINTR)
134 continue;
135 error = errno;
136 break;
137 }
138 PJDLOG_ASSERT(ret > 0);
139 PJDLOG_ASSERT(FD_ISSET(fd, &fdset));
140 break;
141 }
142
143 return (error);
144 }
145
146 static void
ssl_log_errors(void)147 ssl_log_errors(void)
148 {
149 unsigned long error;
150
151 while ((error = ERR_get_error()) != 0)
152 pjdlog_error("SSL error: %s", ERR_error_string(error, NULL));
153 }
154
155 static int
ssl_check_error(SSL * ssl,int ret)156 ssl_check_error(SSL *ssl, int ret)
157 {
158 int error;
159
160 error = SSL_get_error(ssl, ret);
161
162 switch (error) {
163 case SSL_ERROR_NONE:
164 return (0);
165 case SSL_ERROR_WANT_READ:
166 pjdlog_debug(2, "SSL_ERROR_WANT_READ");
167 return (-1);
168 case SSL_ERROR_WANT_WRITE:
169 pjdlog_debug(2, "SSL_ERROR_WANT_WRITE");
170 return (-1);
171 case SSL_ERROR_ZERO_RETURN:
172 pjdlog_exitx(EX_OK, "Connection closed.");
173 case SSL_ERROR_SYSCALL:
174 ssl_log_errors();
175 pjdlog_exitx(EX_TEMPFAIL, "SSL I/O error.");
176 case SSL_ERROR_SSL:
177 ssl_log_errors();
178 pjdlog_exitx(EX_TEMPFAIL, "SSL protocol error.");
179 default:
180 ssl_log_errors();
181 pjdlog_exitx(EX_TEMPFAIL, "Unknown SSL error (%d).", error);
182 }
183 }
184
185 static void
tcp_recv_ssl_send(int recvfd,SSL * sendssl)186 tcp_recv_ssl_send(int recvfd, SSL *sendssl)
187 {
188 static unsigned char buf[65536];
189 ssize_t tcpdone;
190 int sendfd, ssldone;
191
192 sendfd = SSL_get_fd(sendssl);
193 PJDLOG_ASSERT(sendfd >= 0);
194 pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd);
195 for (;;) {
196 tcpdone = recv(recvfd, buf, sizeof(buf), 0);
197 pjdlog_debug(2, "%s: recv() returned %zd", __func__, tcpdone);
198 if (tcpdone == 0) {
199 pjdlog_debug(1, "Connection terminated.");
200 exit(0);
201 } else if (tcpdone == -1) {
202 if (errno == EINTR)
203 continue;
204 else if (errno == EAGAIN)
205 break;
206 pjdlog_exit(EX_TEMPFAIL, "recv() failed");
207 }
208 for (;;) {
209 ssldone = SSL_write(sendssl, buf, (int)tcpdone);
210 pjdlog_debug(2, "%s: send() returned %d", __func__,
211 ssldone);
212 if (ssl_check_error(sendssl, ssldone) == -1) {
213 (void)wait_for_fd(sendfd, -1);
214 continue;
215 }
216 PJDLOG_ASSERT(ssldone == tcpdone);
217 break;
218 }
219 }
220 pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd);
221 }
222
223 static void
ssl_recv_tcp_send(SSL * recvssl,int sendfd)224 ssl_recv_tcp_send(SSL *recvssl, int sendfd)
225 {
226 static unsigned char buf[65536];
227 unsigned char *ptr;
228 ssize_t tcpdone;
229 size_t todo;
230 int recvfd, ssldone;
231
232 recvfd = SSL_get_fd(recvssl);
233 PJDLOG_ASSERT(recvfd >= 0);
234 pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd);
235 for (;;) {
236 ssldone = SSL_read(recvssl, buf, sizeof(buf));
237 pjdlog_debug(2, "%s: SSL_read() returned %d", __func__,
238 ssldone);
239 if (ssl_check_error(recvssl, ssldone) == -1)
240 break;
241 todo = (size_t)ssldone;
242 ptr = buf;
243 do {
244 tcpdone = send(sendfd, ptr, todo, MSG_NOSIGNAL);
245 pjdlog_debug(2, "%s: send() returned %zd", __func__,
246 tcpdone);
247 if (tcpdone == 0) {
248 pjdlog_debug(1, "Connection terminated.");
249 exit(0);
250 } else if (tcpdone == -1) {
251 if (errno == EINTR || errno == ENOBUFS)
252 continue;
253 if (errno == EAGAIN) {
254 (void)wait_for_fd(sendfd, -1);
255 continue;
256 }
257 pjdlog_exit(EX_TEMPFAIL, "send() failed");
258 }
259 todo -= tcpdone;
260 ptr += tcpdone;
261 } while (todo > 0);
262 }
263 pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd);
264 }
265
266 static void
tls_loop(int sockfd,SSL * tcpssl)267 tls_loop(int sockfd, SSL *tcpssl)
268 {
269 fd_set fds;
270 int maxfd, tcpfd;
271
272 tcpfd = SSL_get_fd(tcpssl);
273 PJDLOG_ASSERT(tcpfd >= 0);
274
275 for (;;) {
276 FD_ZERO(&fds);
277 FD_SET(sockfd, &fds);
278 FD_SET(tcpfd, &fds);
279 maxfd = MAX(sockfd, tcpfd);
280
281 PJDLOG_ASSERT(maxfd + 1 <= (int)FD_SETSIZE);
282 if (select(maxfd + 1, &fds, NULL, NULL, NULL) == -1) {
283 if (errno == EINTR)
284 continue;
285 pjdlog_exit(EX_TEMPFAIL, "select() failed");
286 }
287 if (FD_ISSET(sockfd, &fds))
288 tcp_recv_ssl_send(sockfd, tcpssl);
289 if (FD_ISSET(tcpfd, &fds))
290 ssl_recv_tcp_send(tcpssl, sockfd);
291 }
292 }
293
294 static void
tls_certificate_verify(SSL * ssl,const char * fingerprint)295 tls_certificate_verify(SSL *ssl, const char *fingerprint)
296 {
297 unsigned char md[EVP_MAX_MD_SIZE];
298 char mdstr[sizeof("SHA256=") - 1 + EVP_MAX_MD_SIZE * 3];
299 char *mdstrp;
300 unsigned int i, mdsize;
301 X509 *cert;
302
303 if (fingerprint[0] == '\0') {
304 pjdlog_debug(1, "No fingerprint verification requested.");
305 return;
306 }
307
308 cert = SSL_get_peer_certificate(ssl);
309 if (cert == NULL)
310 pjdlog_exitx(EX_TEMPFAIL, "No peer certificate received.");
311
312 if (X509_digest(cert, EVP_sha256(), md, &mdsize) != 1)
313 pjdlog_exitx(EX_TEMPFAIL, "X509_digest() failed.");
314 PJDLOG_ASSERT(mdsize <= EVP_MAX_MD_SIZE);
315
316 X509_free(cert);
317
318 (void)strlcpy(mdstr, "SHA256=", sizeof(mdstr));
319 mdstrp = mdstr + strlen(mdstr);
320 for (i = 0; i < mdsize; i++) {
321 PJDLOG_VERIFY(mdstrp + 3 <= mdstr + sizeof(mdstr));
322 (void)sprintf(mdstrp, "%02hhX:", md[i]);
323 mdstrp += 3;
324 }
325 /* Clear last colon. */
326 mdstrp[-1] = '\0';
327 if (strcasecmp(mdstr, fingerprint) != 0) {
328 pjdlog_exitx(EX_NOPERM,
329 "Finger print doesn't match. Received \"%s\", expected \"%s\"",
330 mdstr, fingerprint);
331 }
332 }
333
334 static void
tls_exec_client(const char * user,int startfd,const char * srcaddr,const char * dstaddr,const char * fingerprint,const char * defport,int timeout,int debuglevel)335 tls_exec_client(const char *user, int startfd, const char *srcaddr,
336 const char *dstaddr, const char *fingerprint, const char *defport,
337 int timeout, int debuglevel)
338 {
339 struct proto_conn *tcp;
340 char *saddr, *daddr;
341 SSL_CTX *sslctx;
342 SSL *ssl;
343 long ret;
344 int sockfd, tcpfd;
345 uint8_t connected;
346
347 pjdlog_debug_set(debuglevel);
348 pjdlog_prefix_set("[TLS sandbox] (client) ");
349 #ifdef HAVE_SETPROCTITLE
350 setproctitle("[TLS sandbox] (client) ");
351 #endif
352 proto_set("tcp:port", defport);
353
354 sockfd = startfd;
355
356 /* Change tls:// to tcp://. */
357 if (srcaddr == NULL) {
358 saddr = NULL;
359 } else {
360 saddr = strdup(srcaddr);
361 if (saddr == NULL)
362 pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory.");
363 bcopy("tcp://", saddr, 6);
364 }
365 daddr = strdup(dstaddr);
366 if (daddr == NULL)
367 pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory.");
368 bcopy("tcp://", daddr, 6);
369
370 /* Establish TCP connection. */
371 if (proto_connect(saddr, daddr, timeout, &tcp) == -1)
372 exit(EX_TEMPFAIL);
373
374 #if OPENSSL_VERSION_NUMBER < 0x10100000L
375 SSL_load_error_strings();
376 SSL_library_init();
377 #endif
378
379 /*
380 * TODO: On FreeBSD we could move this below sandbox() once libc and
381 * libcrypto use sysctl kern.arandom to obtain random data
382 * instead of /dev/urandom and friends.
383 */
384 sslctx = SSL_CTX_new(TLS_client_method());
385 if (sslctx == NULL)
386 pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed.");
387
388 if (sandbox(user, true, "proto_tls client: %s", dstaddr) != 0)
389 pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS client.");
390 pjdlog_debug(1, "Privileges successfully dropped.");
391
392 SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
393
394 /* Load CA certs. */
395 /* TODO */
396 //SSL_CTX_load_verify_locations(sslctx, cacerts_file, NULL);
397
398 ssl = SSL_new(sslctx);
399 if (ssl == NULL)
400 pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed.");
401
402 tcpfd = proto_descriptor(tcp);
403
404 block(tcpfd);
405
406 if (SSL_set_fd(ssl, tcpfd) != 1)
407 pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed.");
408
409 ret = SSL_connect(ssl);
410 ssl_check_error(ssl, (int)ret);
411
412 nonblock(sockfd);
413 nonblock(tcpfd);
414
415 tls_certificate_verify(ssl, fingerprint);
416
417 /*
418 * The following byte is sent to make proto_connect_wait() work.
419 */
420 connected = 1;
421 for (;;) {
422 switch (send(sockfd, &connected, sizeof(connected), 0)) {
423 case -1:
424 if (errno == EINTR || errno == ENOBUFS)
425 continue;
426 if (errno == EAGAIN) {
427 (void)wait_for_fd(sockfd, -1);
428 continue;
429 }
430 pjdlog_exit(EX_TEMPFAIL, "send() failed");
431 case 0:
432 pjdlog_debug(1, "Connection terminated.");
433 exit(0);
434 case 1:
435 break;
436 }
437 break;
438 }
439
440 tls_loop(sockfd, ssl);
441 }
442
443 static void
tls_call_exec_client(struct proto_conn * sock,const char * srcaddr,const char * dstaddr,int timeout)444 tls_call_exec_client(struct proto_conn *sock, const char *srcaddr,
445 const char *dstaddr, int timeout)
446 {
447 char *timeoutstr, *startfdstr, *debugstr;
448 int startfd;
449
450 /* Declare that we are receiver. */
451 proto_recv(sock, NULL, 0);
452
453 if (pjdlog_mode_get() == PJDLOG_MODE_STD)
454 startfd = 3;
455 else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */
456 startfd = 0;
457
458 if (proto_descriptor(sock) != startfd) {
459 /* Move socketpair descriptor to descriptor number startfd. */
460 if (dup2(proto_descriptor(sock), startfd) == -1)
461 pjdlog_exit(EX_OSERR, "dup2() failed");
462 proto_close(sock);
463 } else {
464 /*
465 * The FD_CLOEXEC is cleared by dup2(2), so when we do not
466 * call it, we have to clear it by hand in case it is set.
467 */
468 if (fcntl(startfd, F_SETFD, 0) == -1)
469 pjdlog_exit(EX_OSERR, "fcntl() failed");
470 }
471
472 closefrom(startfd + 1);
473
474 if (asprintf(&startfdstr, "%d", startfd) == -1)
475 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
476 if (timeout == -1)
477 timeout = TLS_DEFAULT_TIMEOUT;
478 if (asprintf(&timeoutstr, "%d", timeout) == -1)
479 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
480 if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1)
481 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
482
483 execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls",
484 proto_get("user"), "client", startfdstr,
485 srcaddr == NULL ? "" : srcaddr, dstaddr,
486 proto_get("tls:fingerprint"), proto_get("tcp:port"), timeoutstr,
487 debugstr, NULL);
488 pjdlog_exit(EX_SOFTWARE, "execl() failed");
489 }
490
491 static int
tls_connect(const char * srcaddr,const char * dstaddr,int timeout,void ** ctxp)492 tls_connect(const char *srcaddr, const char *dstaddr, int timeout, void **ctxp)
493 {
494 struct tls_ctx *tlsctx;
495 struct proto_conn *sock;
496 pid_t pid;
497 int error;
498
499 PJDLOG_ASSERT(srcaddr == NULL || srcaddr[0] != '\0');
500 PJDLOG_ASSERT(dstaddr != NULL);
501 PJDLOG_ASSERT(timeout >= -1);
502 PJDLOG_ASSERT(ctxp != NULL);
503
504 if (strncmp(dstaddr, "tls://", 6) != 0)
505 return (-1);
506 if (srcaddr != NULL && strncmp(srcaddr, "tls://", 6) != 0)
507 return (-1);
508
509 if (proto_connect(NULL, "socketpair://", -1, &sock) == -1)
510 return (errno);
511
512 #if 0
513 /*
514 * We use rfork() with the following flags to disable SIGCHLD
515 * delivery upon the sandbox process exit.
516 */
517 pid = rfork(RFFDG | RFPROC | RFTSIGZMB | RFTSIGFLAGS(0));
518 #else
519 /*
520 * We don't use rfork() to be able to log information about sandbox
521 * process exiting.
522 */
523 pid = fork();
524 #endif
525 switch (pid) {
526 case -1:
527 /* Failure. */
528 error = errno;
529 proto_close(sock);
530 return (error);
531 case 0:
532 /* Child. */
533 pjdlog_prefix_set("[TLS sandbox] (client) ");
534 #ifdef HAVE_SETPROCTITLE
535 setproctitle("[TLS sandbox] (client) ");
536 #endif
537 tls_call_exec_client(sock, srcaddr, dstaddr, timeout);
538 /* NOTREACHED */
539 default:
540 /* Parent. */
541 tlsctx = calloc(1, sizeof(*tlsctx));
542 if (tlsctx == NULL) {
543 error = errno;
544 proto_close(sock);
545 (void)kill(pid, SIGKILL);
546 return (error);
547 }
548 proto_send(sock, NULL, 0);
549 tlsctx->tls_sock = sock;
550 tlsctx->tls_tcp = NULL;
551 tlsctx->tls_side = TLS_SIDE_CLIENT;
552 tlsctx->tls_wait_called = false;
553 tlsctx->tls_magic = TLS_CTX_MAGIC;
554 if (timeout >= 0) {
555 error = tls_connect_wait(tlsctx, timeout);
556 if (error != 0) {
557 (void)kill(pid, SIGKILL);
558 tls_close(tlsctx);
559 return (error);
560 }
561 }
562 *ctxp = tlsctx;
563 return (0);
564 }
565 }
566
567 static int
tls_connect_wait(void * ctx,int timeout)568 tls_connect_wait(void *ctx, int timeout)
569 {
570 struct tls_ctx *tlsctx = ctx;
571 int error, sockfd;
572 uint8_t connected;
573
574 PJDLOG_ASSERT(tlsctx != NULL);
575 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
576 PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT);
577 PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
578 PJDLOG_ASSERT(!tlsctx->tls_wait_called);
579 PJDLOG_ASSERT(timeout >= 0);
580
581 sockfd = proto_descriptor(tlsctx->tls_sock);
582 error = wait_for_fd(sockfd, timeout);
583 if (error != 0)
584 return (error);
585
586 for (;;) {
587 switch (recv(sockfd, &connected, sizeof(connected),
588 MSG_WAITALL)) {
589 case -1:
590 if (errno == EINTR || errno == ENOBUFS)
591 continue;
592 error = errno;
593 break;
594 case 0:
595 pjdlog_debug(1, "Connection terminated.");
596 error = ENOTCONN;
597 break;
598 case 1:
599 tlsctx->tls_wait_called = true;
600 break;
601 }
602 break;
603 }
604
605 return (error);
606 }
607
608 static int
tls_server(const char * lstaddr,void ** ctxp)609 tls_server(const char *lstaddr, void **ctxp)
610 {
611 struct proto_conn *tcp;
612 struct tls_ctx *tlsctx;
613 char *laddr;
614 int error;
615
616 if (strncmp(lstaddr, "tls://", 6) != 0)
617 return (-1);
618
619 tlsctx = malloc(sizeof(*tlsctx));
620 if (tlsctx == NULL) {
621 pjdlog_warning("Unable to allocate memory.");
622 return (ENOMEM);
623 }
624
625 laddr = strdup(lstaddr);
626 if (laddr == NULL) {
627 free(tlsctx);
628 pjdlog_warning("Unable to allocate memory.");
629 return (ENOMEM);
630 }
631 bcopy("tcp://", laddr, 6);
632
633 if (proto_server(laddr, &tcp) == -1) {
634 error = errno;
635 free(tlsctx);
636 free(laddr);
637 return (error);
638 }
639 free(laddr);
640
641 tlsctx->tls_sock = NULL;
642 tlsctx->tls_tcp = tcp;
643 tlsctx->tls_side = TLS_SIDE_SERVER_LISTEN;
644 tlsctx->tls_wait_called = true;
645 tlsctx->tls_magic = TLS_CTX_MAGIC;
646 *ctxp = tlsctx;
647
648 return (0);
649 }
650
651 static void
tls_exec_server(const char * user,int startfd,const char * privkey,const char * cert,int debuglevel)652 tls_exec_server(const char *user, int startfd, const char *privkey,
653 const char *cert, int debuglevel)
654 {
655 SSL_CTX *sslctx;
656 SSL *ssl;
657 int sockfd, tcpfd, ret;
658
659 pjdlog_debug_set(debuglevel);
660 pjdlog_prefix_set("[TLS sandbox] (server) ");
661 #ifdef HAVE_SETPROCTITLE
662 setproctitle("[TLS sandbox] (server) ");
663 #endif
664
665 sockfd = startfd;
666 tcpfd = startfd + 1;
667
668 #if OPENSSL_VERSION_NUMBER < 0x10100000L
669 SSL_load_error_strings();
670 SSL_library_init();
671 #endif
672
673 sslctx = SSL_CTX_new(TLS_server_method());
674 if (sslctx == NULL)
675 pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed.");
676
677 SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
678
679 ssl = SSL_new(sslctx);
680 if (ssl == NULL)
681 pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed.");
682
683 if (SSL_use_RSAPrivateKey_file(ssl, privkey, SSL_FILETYPE_PEM) != 1) {
684 ssl_log_errors();
685 pjdlog_exitx(EX_CONFIG,
686 "SSL_use_RSAPrivateKey_file(%s) failed.", privkey);
687 }
688
689 if (SSL_use_certificate_file(ssl, cert, SSL_FILETYPE_PEM) != 1) {
690 ssl_log_errors();
691 pjdlog_exitx(EX_CONFIG, "SSL_use_certificate_file(%s) failed.",
692 cert);
693 }
694
695 if (sandbox(user, true, "proto_tls server") != 0)
696 pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS server.");
697 pjdlog_debug(1, "Privileges successfully dropped.");
698
699 nonblock(sockfd);
700 nonblock(tcpfd);
701
702 if (SSL_set_fd(ssl, tcpfd) != 1)
703 pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed.");
704
705 ret = SSL_accept(ssl);
706 ssl_check_error(ssl, ret);
707
708 tls_loop(sockfd, ssl);
709 }
710
/*
 * Child side of tls_accept().
 *
 * Places the socketpair end `sock' at descriptor `startfd' and the
 * accepted TCP connection `tcp' at `startfd + 1' (where tls_exec_server()
 * expects them) and closes every descriptor above.  It then re-executes
 * the program with "proto", "tls", user, "server", startfd, key file,
 * certificate file and debug level.  Presumably this is dispatched to
 * tls_exec() in the new image.  Does not return.
 */
static void
tls_call_exec_server(struct proto_conn *sock, struct proto_conn *tcp)
{
	int startfd, sockfd, tcpfd, safefd;
	char *startfdstr, *debugstr;

	/* tls_exec() derives the log mode from this number (0 => syslog). */
	if (pjdlog_mode_get() == PJDLOG_MODE_STD)
		startfd = 3;
	else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */
		startfd = 0;

	/* Declare that we are sender; the parent, tls_accept(), receives. */
	proto_send(sock, NULL, 0);

	sockfd = proto_descriptor(sock);
	tcpfd = proto_descriptor(tcp);

	/*
	 * Pick descriptor numbers above sockfd, tcpfd and startfd.  The
	 * final dup2() calls onto startfd and startfd + 1 then cannot
	 * clobber a descriptor that still has to be moved.
	 */
	safefd = MAX(sockfd, tcpfd);
	safefd = MAX(safefd, startfd);
	safefd++;

	/* Move sockfd and tcpfd to safe numbers first. */
	if (dup2(sockfd, safefd) == -1)
		pjdlog_exit(EX_OSERR, "dup2() failed");
	proto_close(sock);
	sockfd = safefd;
	if (dup2(tcpfd, safefd + 1) == -1)
		pjdlog_exit(EX_OSERR, "dup2() failed");
	proto_close(tcp);
	tcpfd = safefd + 1;

	/* Move socketpair descriptor to descriptor number startfd. */
	if (dup2(sockfd, startfd) == -1)
		pjdlog_exit(EX_OSERR, "dup2() failed");
	(void)close(sockfd);
	/* Move tcp descriptor to descriptor number startfd + 1. */
	if (dup2(tcpfd, startfd + 1) == -1)
		pjdlog_exit(EX_OSERR, "dup2() failed");
	(void)close(tcpfd);

	closefrom(startfd + 2);

	/*
	 * Even if FD_CLOEXEC was set on descriptors before dup2(), it should
	 * have been cleared on dup2(), but better be safe than sorry.
	 */
	if (fcntl(startfd, F_SETFD, 0) == -1)
		pjdlog_exit(EX_OSERR, "fcntl() failed");
	if (fcntl(startfd + 1, F_SETFD, 0) == -1)
		pjdlog_exit(EX_OSERR, "fcntl() failed");

	if (asprintf(&startfdstr, "%d", startfd) == -1)
		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
	if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1)
		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");

	execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls",
	    proto_get("user"), "server", startfdstr, proto_get("tls:keyfile"),
	    proto_get("tls:certfile"), debugstr, NULL);
	pjdlog_exit(EX_SOFTWARE, "execl() failed");
}
772
773 static int
tls_accept(void * ctx,void ** newctxp)774 tls_accept(void *ctx, void **newctxp)
775 {
776 struct tls_ctx *tlsctx = ctx;
777 struct tls_ctx *newtlsctx;
778 struct proto_conn *sock, *tcp;
779 pid_t pid;
780 int error;
781
782 PJDLOG_ASSERT(tlsctx != NULL);
783 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
784 PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_SERVER_LISTEN);
785
786 if (proto_connect(NULL, "socketpair://", -1, &sock) == -1)
787 return (errno);
788
789 /* Accept TCP connection. */
790 if (proto_accept(tlsctx->tls_tcp, &tcp) == -1) {
791 error = errno;
792 proto_close(sock);
793 return (error);
794 }
795
796 pid = fork();
797 switch (pid) {
798 case -1:
799 /* Failure. */
800 error = errno;
801 proto_close(sock);
802 return (error);
803 case 0:
804 /* Child. */
805 pjdlog_prefix_set("[TLS sandbox] (server) ");
806 #ifdef HAVE_SETPROCTITLE
807 setproctitle("[TLS sandbox] (server) ");
808 #endif
809 /* Close listen socket. */
810 proto_close(tlsctx->tls_tcp);
811 tls_call_exec_server(sock, tcp);
812 /* NOTREACHED */
813 PJDLOG_ABORT("Unreachable.");
814 default:
815 /* Parent. */
816 newtlsctx = calloc(1, sizeof(*tlsctx));
817 if (newtlsctx == NULL) {
818 error = errno;
819 proto_close(sock);
820 proto_close(tcp);
821 (void)kill(pid, SIGKILL);
822 return (error);
823 }
824 proto_local_address(tcp, newtlsctx->tls_laddr,
825 sizeof(newtlsctx->tls_laddr));
826 PJDLOG_ASSERT(strncmp(newtlsctx->tls_laddr, "tcp://", 6) == 0);
827 bcopy("tls://", newtlsctx->tls_laddr, 6);
828 *strrchr(newtlsctx->tls_laddr, ':') = '\0';
829 proto_remote_address(tcp, newtlsctx->tls_raddr,
830 sizeof(newtlsctx->tls_raddr));
831 PJDLOG_ASSERT(strncmp(newtlsctx->tls_raddr, "tcp://", 6) == 0);
832 bcopy("tls://", newtlsctx->tls_raddr, 6);
833 *strrchr(newtlsctx->tls_raddr, ':') = '\0';
834 proto_close(tcp);
835 proto_recv(sock, NULL, 0);
836 newtlsctx->tls_sock = sock;
837 newtlsctx->tls_tcp = NULL;
838 newtlsctx->tls_wait_called = true;
839 newtlsctx->tls_side = TLS_SIDE_SERVER_WORK;
840 newtlsctx->tls_magic = TLS_CTX_MAGIC;
841 *newctxp = newtlsctx;
842 return (0);
843 }
844 }
845
846 static int
tls_wrap(int fd,bool client,void ** ctxp)847 tls_wrap(int fd, bool client, void **ctxp)
848 {
849 struct tls_ctx *tlsctx;
850 struct proto_conn *sock;
851 int error;
852
853 tlsctx = calloc(1, sizeof(*tlsctx));
854 if (tlsctx == NULL)
855 return (errno);
856
857 if (proto_wrap("socketpair", client, fd, &sock) == -1) {
858 error = errno;
859 free(tlsctx);
860 return (error);
861 }
862
863 tlsctx->tls_sock = sock;
864 tlsctx->tls_tcp = NULL;
865 tlsctx->tls_wait_called = (client ? false : true);
866 tlsctx->tls_side = (client ? TLS_SIDE_CLIENT : TLS_SIDE_SERVER_WORK);
867 tlsctx->tls_magic = TLS_CTX_MAGIC;
868 *ctxp = tlsctx;
869
870 return (0);
871 }
872
873 static int
tls_send(void * ctx,const unsigned char * data,size_t size,int fd)874 tls_send(void *ctx, const unsigned char *data, size_t size, int fd)
875 {
876 struct tls_ctx *tlsctx = ctx;
877
878 PJDLOG_ASSERT(tlsctx != NULL);
879 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
880 PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT ||
881 tlsctx->tls_side == TLS_SIDE_SERVER_WORK);
882 PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
883 PJDLOG_ASSERT(tlsctx->tls_wait_called);
884 PJDLOG_ASSERT(fd == -1);
885
886 if (proto_send(tlsctx->tls_sock, data, size) == -1)
887 return (errno);
888
889 return (0);
890 }
891
892 static int
tls_recv(void * ctx,unsigned char * data,size_t size,int * fdp)893 tls_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
894 {
895 struct tls_ctx *tlsctx = ctx;
896
897 PJDLOG_ASSERT(tlsctx != NULL);
898 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
899 PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT ||
900 tlsctx->tls_side == TLS_SIDE_SERVER_WORK);
901 PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
902 PJDLOG_ASSERT(tlsctx->tls_wait_called);
903 PJDLOG_ASSERT(fdp == NULL);
904
905 if (proto_recv(tlsctx->tls_sock, data, size) == -1)
906 return (errno);
907
908 return (0);
909 }
910
911 static int
tls_descriptor(const void * ctx)912 tls_descriptor(const void *ctx)
913 {
914 const struct tls_ctx *tlsctx = ctx;
915
916 PJDLOG_ASSERT(tlsctx != NULL);
917 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
918
919 switch (tlsctx->tls_side) {
920 case TLS_SIDE_CLIENT:
921 case TLS_SIDE_SERVER_WORK:
922 PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
923
924 return (proto_descriptor(tlsctx->tls_sock));
925 case TLS_SIDE_SERVER_LISTEN:
926 PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
927
928 return (proto_descriptor(tlsctx->tls_tcp));
929 default:
930 PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
931 }
932 }
933
934 static bool
tcp_address_match(const void * ctx,const char * addr)935 tcp_address_match(const void *ctx, const char *addr)
936 {
937 const struct tls_ctx *tlsctx = ctx;
938
939 PJDLOG_ASSERT(tlsctx != NULL);
940 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
941
942 return (strcmp(tlsctx->tls_raddr, addr) == 0);
943 }
944
945 static void
tls_local_address(const void * ctx,char * addr,size_t size)946 tls_local_address(const void *ctx, char *addr, size_t size)
947 {
948 const struct tls_ctx *tlsctx = ctx;
949
950 PJDLOG_ASSERT(tlsctx != NULL);
951 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
952 PJDLOG_ASSERT(tlsctx->tls_wait_called);
953
954 switch (tlsctx->tls_side) {
955 case TLS_SIDE_CLIENT:
956 PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
957
958 PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size);
959 break;
960 case TLS_SIDE_SERVER_WORK:
961 PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
962
963 PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_laddr, size) < size);
964 break;
965 case TLS_SIDE_SERVER_LISTEN:
966 PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
967
968 proto_local_address(tlsctx->tls_tcp, addr, size);
969 PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0);
970 /* Replace tcp:// prefix with tls:// */
971 bcopy("tls://", addr, 6);
972 break;
973 default:
974 PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
975 }
976 }
977
978 static void
tls_remote_address(const void * ctx,char * addr,size_t size)979 tls_remote_address(const void *ctx, char *addr, size_t size)
980 {
981 const struct tls_ctx *tlsctx = ctx;
982
983 PJDLOG_ASSERT(tlsctx != NULL);
984 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
985 PJDLOG_ASSERT(tlsctx->tls_wait_called);
986
987 switch (tlsctx->tls_side) {
988 case TLS_SIDE_CLIENT:
989 PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
990
991 PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size);
992 break;
993 case TLS_SIDE_SERVER_WORK:
994 PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
995
996 PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_raddr, size) < size);
997 break;
998 case TLS_SIDE_SERVER_LISTEN:
999 PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
1000
1001 proto_remote_address(tlsctx->tls_tcp, addr, size);
1002 PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0);
1003 /* Replace tcp:// prefix with tls:// */
1004 bcopy("tls://", addr, 6);
1005 break;
1006 default:
1007 PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
1008 }
1009 }
1010
1011 static void
tls_close(void * ctx)1012 tls_close(void *ctx)
1013 {
1014 struct tls_ctx *tlsctx = ctx;
1015
1016 PJDLOG_ASSERT(tlsctx != NULL);
1017 PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
1018
1019 if (tlsctx->tls_sock != NULL) {
1020 proto_close(tlsctx->tls_sock);
1021 tlsctx->tls_sock = NULL;
1022 }
1023 if (tlsctx->tls_tcp != NULL) {
1024 proto_close(tlsctx->tls_tcp);
1025 tlsctx->tls_tcp = NULL;
1026 }
1027 tlsctx->tls_side = 0;
1028 tlsctx->tls_magic = 0;
1029 free(tlsctx);
1030 }
1031
1032 static int
tls_exec(int argc,char * argv[])1033 tls_exec(int argc, char *argv[])
1034 {
1035
1036 PJDLOG_ASSERT(argc > 3);
1037 PJDLOG_ASSERT(strcmp(argv[0], "tls") == 0);
1038
1039 pjdlog_init(atoi(argv[3]) == 0 ? PJDLOG_MODE_SYSLOG : PJDLOG_MODE_STD);
1040
1041 if (strcmp(argv[2], "client") == 0) {
1042 if (argc != 10)
1043 return (EINVAL);
1044 tls_exec_client(argv[1], atoi(argv[3]),
1045 argv[4][0] == '\0' ? NULL : argv[4], argv[5], argv[6],
1046 argv[7], atoi(argv[8]), atoi(argv[9]));
1047 } else if (strcmp(argv[2], "server") == 0) {
1048 if (argc != 7)
1049 return (EINVAL);
1050 tls_exec_server(argv[1], atoi(argv[3]), argv[4], argv[5],
1051 atoi(argv[6]));
1052 }
1053 return (EINVAL);
1054 }
1055
1056 static struct proto tls_proto = {
1057 .prt_name = "tls",
1058 .prt_connect = tls_connect,
1059 .prt_connect_wait = tls_connect_wait,
1060 .prt_server = tls_server,
1061 .prt_accept = tls_accept,
1062 .prt_wrap = tls_wrap,
1063 .prt_send = tls_send,
1064 .prt_recv = tls_recv,
1065 .prt_descriptor = tls_descriptor,
1066 .prt_address_match = tcp_address_match,
1067 .prt_local_address = tls_local_address,
1068 .prt_remote_address = tls_remote_address,
1069 .prt_close = tls_close,
1070 .prt_exec = tls_exec
1071 };
1072
1073 static __constructor void
tls_ctor(void)1074 tls_ctor(void)
1075 {
1076
1077 proto_register(&tls_proto, false);
1078 }
1079