xref: /freebsd/contrib/openbsm/bin/auditdistd/proto_tls.c (revision bc5304a006238115291e7568583632889dffbab9)
1 /*-
2  * Copyright (c) 2011 The FreeBSD Foundation
3  * All rights reserved.
4  *
5  * This software was developed by Pawel Jakub Dawidek under sponsorship from
6  * the FreeBSD Foundation.
7  *
8  * Redistribution and use in source and binary forms, with or without
9  * modification, are permitted provided that the following conditions
10  * are met:
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in the
15  *    documentation and/or other materials provided with the distribution.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27  * SUCH DAMAGE.
28  */
29 
30 #include <config/config.h>
31 
32 #include <sys/param.h>	/* MAXHOSTNAMELEN */
33 #include <sys/socket.h>
34 
35 #include <arpa/inet.h>
36 
37 #include <netinet/in.h>
38 #include <netinet/tcp.h>
39 
40 #include <errno.h>
41 #include <fcntl.h>
42 #include <netdb.h>
43 #include <signal.h>
44 #include <stdbool.h>
45 #include <stdint.h>
46 #include <stdio.h>
47 #include <string.h>
48 #include <unistd.h>
49 
50 #include <openssl/err.h>
51 #include <openssl/ssl.h>
52 
53 #include <compat/compat.h>
54 #ifndef HAVE_CLOSEFROM
55 #include <compat/closefrom.h>
56 #endif
57 #ifndef HAVE_STRLCPY
58 #include <compat/strlcpy.h>
59 #endif
60 
61 #include "pjdlog.h"
62 #include "proto_impl.h"
63 #include "sandbox.h"
64 #include "subr.h"
65 
/* Marker kept in tls_magic while a context is valid; cleared in tls_close(). */
#define	TLS_CTX_MAGIC	0x715c7

/* Possible values of tls_ctx.tls_side. */
#define	TLS_SIDE_CLIENT		0
#define	TLS_SIDE_SERVER_LISTEN	1
#define	TLS_SIDE_SERVER_WORK	2

/*
 * State of one "tls" protocol connection as seen by the main process.
 */
struct tls_ctx {
	/* Equals TLS_CTX_MAGIC for a live context. */
	int		tls_magic;
	/* Socketpair connection to the TLS sandbox process (client/work). */
	struct proto_conn *tls_sock;
	/* TCP connection; set for the listening side, NULL otherwise. */
	struct proto_conn *tls_tcp;
	/* Local and remote "tls://" addresses recorded by tls_accept(). */
	char		tls_laddr[256];
	char		tls_raddr[256];
	/* One of the TLS_SIDE_* values above. */
	int		tls_side;
	/* True once the connection may be used for send/recv. */
	bool		tls_wait_called;
};
79 
80 #define	TLS_DEFAULT_TIMEOUT	30
81 
82 static int tls_connect_wait(void *ctx, int timeout);
83 static void tls_close(void *ctx);
84 
85 static void
86 block(int fd)
87 {
88 	int flags;
89 
90 	flags = fcntl(fd, F_GETFL);
91 	if (flags == -1)
92 		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed");
93 	flags &= ~O_NONBLOCK;
94 	if (fcntl(fd, F_SETFL, flags) == -1)
95 		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed");
96 }
97 
98 static void
99 nonblock(int fd)
100 {
101 	int flags;
102 
103 	flags = fcntl(fd, F_GETFL);
104 	if (flags == -1)
105 		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed");
106 	flags |= O_NONBLOCK;
107 	if (fcntl(fd, F_SETFL, flags) == -1)
108 		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed");
109 }
110 
111 static int
112 wait_for_fd(int fd, int timeout)
113 {
114 	struct timeval tv;
115 	fd_set fdset;
116 	int error, ret;
117 
118 	error = 0;
119 
120 	for (;;) {
121 		FD_ZERO(&fdset);
122 		FD_SET(fd, &fdset);
123 
124 		tv.tv_sec = timeout;
125 		tv.tv_usec = 0;
126 
127 		ret = select(fd + 1, NULL, &fdset, NULL,
128 		    timeout == -1 ? NULL : &tv);
129 		if (ret == 0) {
130 			error = ETIMEDOUT;
131 			break;
132 		} else if (ret == -1) {
133 			if (errno == EINTR)
134 				continue;
135 			error = errno;
136 			break;
137 		}
138 		PJDLOG_ASSERT(ret > 0);
139 		PJDLOG_ASSERT(FD_ISSET(fd, &fdset));
140 		break;
141 	}
142 
143 	return (error);
144 }
145 
146 static void
147 ssl_log_errors(void)
148 {
149 	unsigned long error;
150 
151 	while ((error = ERR_get_error()) != 0)
152 		pjdlog_error("SSL error: %s", ERR_error_string(error, NULL));
153 }
154 
155 static int
156 ssl_check_error(SSL *ssl, int ret)
157 {
158 	int error;
159 
160 	error = SSL_get_error(ssl, ret);
161 
162 	switch (error) {
163 	case SSL_ERROR_NONE:
164 		return (0);
165 	case SSL_ERROR_WANT_READ:
166 		pjdlog_debug(2, "SSL_ERROR_WANT_READ");
167 		return (-1);
168 	case SSL_ERROR_WANT_WRITE:
169 		pjdlog_debug(2, "SSL_ERROR_WANT_WRITE");
170 		return (-1);
171 	case SSL_ERROR_ZERO_RETURN:
172 		pjdlog_exitx(EX_OK, "Connection closed.");
173 	case SSL_ERROR_SYSCALL:
174 		ssl_log_errors();
175 		pjdlog_exitx(EX_TEMPFAIL, "SSL I/O error.");
176 	case SSL_ERROR_SSL:
177 		ssl_log_errors();
178 		pjdlog_exitx(EX_TEMPFAIL, "SSL protocol error.");
179 	default:
180 		ssl_log_errors();
181 		pjdlog_exitx(EX_TEMPFAIL, "Unknown SSL error (%d).", error);
182 	}
183 }
184 
185 static void
186 tcp_recv_ssl_send(int recvfd, SSL *sendssl)
187 {
188 	static unsigned char buf[65536];
189 	ssize_t tcpdone;
190 	int sendfd, ssldone;
191 
192 	sendfd = SSL_get_fd(sendssl);
193 	PJDLOG_ASSERT(sendfd >= 0);
194 	pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd);
195 	for (;;) {
196 		tcpdone = recv(recvfd, buf, sizeof(buf), 0);
197 		pjdlog_debug(2, "%s: recv() returned %zd", __func__, tcpdone);
198 		if (tcpdone == 0) {
199 			pjdlog_debug(1, "Connection terminated.");
200 			exit(0);
201 		} else if (tcpdone == -1) {
202 			if (errno == EINTR)
203 				continue;
204 			else if (errno == EAGAIN)
205 				break;
206 			pjdlog_exit(EX_TEMPFAIL, "recv() failed");
207 		}
208 		for (;;) {
209 			ssldone = SSL_write(sendssl, buf, (int)tcpdone);
210 			pjdlog_debug(2, "%s: send() returned %d", __func__,
211 			    ssldone);
212 			if (ssl_check_error(sendssl, ssldone) == -1) {
213 				(void)wait_for_fd(sendfd, -1);
214 				continue;
215 			}
216 			PJDLOG_ASSERT(ssldone == tcpdone);
217 			break;
218 		}
219 	}
220 	pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd);
221 }
222 
223 static void
224 ssl_recv_tcp_send(SSL *recvssl, int sendfd)
225 {
226 	static unsigned char buf[65536];
227 	unsigned char *ptr;
228 	ssize_t tcpdone;
229 	size_t todo;
230 	int recvfd, ssldone;
231 
232 	recvfd = SSL_get_fd(recvssl);
233 	PJDLOG_ASSERT(recvfd >= 0);
234 	pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd);
235 	for (;;) {
236 		ssldone = SSL_read(recvssl, buf, sizeof(buf));
237 		pjdlog_debug(2, "%s: SSL_read() returned %d", __func__,
238 		    ssldone);
239 		if (ssl_check_error(recvssl, ssldone) == -1)
240 			break;
241 		todo = (size_t)ssldone;
242 		ptr = buf;
243 		do {
244 			tcpdone = send(sendfd, ptr, todo, MSG_NOSIGNAL);
245 			pjdlog_debug(2, "%s: send() returned %zd", __func__,
246 			    tcpdone);
247 			if (tcpdone == 0) {
248 				pjdlog_debug(1, "Connection terminated.");
249 				exit(0);
250 			} else if (tcpdone == -1) {
251 				if (errno == EINTR || errno == ENOBUFS)
252 					continue;
253 				if (errno == EAGAIN) {
254 					(void)wait_for_fd(sendfd, -1);
255 					continue;
256 				}
257 				pjdlog_exit(EX_TEMPFAIL, "send() failed");
258 			}
259 			todo -= tcpdone;
260 			ptr += tcpdone;
261 		} while (todo > 0);
262 	}
263 	pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd);
264 }
265 
266 static void
267 tls_loop(int sockfd, SSL *tcpssl)
268 {
269 	fd_set fds;
270 	int maxfd, tcpfd;
271 
272 	tcpfd = SSL_get_fd(tcpssl);
273 	PJDLOG_ASSERT(tcpfd >= 0);
274 
275 	for (;;) {
276 		FD_ZERO(&fds);
277 		FD_SET(sockfd, &fds);
278 		FD_SET(tcpfd, &fds);
279 		maxfd = MAX(sockfd, tcpfd);
280 
281 		PJDLOG_ASSERT(maxfd + 1 <= (int)FD_SETSIZE);
282 		if (select(maxfd + 1, &fds, NULL, NULL, NULL) == -1) {
283 			if (errno == EINTR)
284 				continue;
285 			pjdlog_exit(EX_TEMPFAIL, "select() failed");
286 		}
287 		if (FD_ISSET(sockfd, &fds))
288 			tcp_recv_ssl_send(sockfd, tcpssl);
289 		if (FD_ISSET(tcpfd, &fds))
290 			ssl_recv_tcp_send(tcpssl, sockfd);
291 	}
292 }
293 
294 static void
295 tls_certificate_verify(SSL *ssl, const char *fingerprint)
296 {
297 	unsigned char md[EVP_MAX_MD_SIZE];
298 	char mdstr[sizeof("SHA256=") - 1 + EVP_MAX_MD_SIZE * 3];
299 	char *mdstrp;
300 	unsigned int i, mdsize;
301 	X509 *cert;
302 
303 	if (fingerprint[0] == '\0') {
304 		pjdlog_debug(1, "No fingerprint verification requested.");
305 		return;
306 	}
307 
308 	cert = SSL_get_peer_certificate(ssl);
309 	if (cert == NULL)
310 		pjdlog_exitx(EX_TEMPFAIL, "No peer certificate received.");
311 
312 	if (X509_digest(cert, EVP_sha256(), md, &mdsize) != 1)
313 		pjdlog_exitx(EX_TEMPFAIL, "X509_digest() failed.");
314 	PJDLOG_ASSERT(mdsize <= EVP_MAX_MD_SIZE);
315 
316 	X509_free(cert);
317 
318 	(void)strlcpy(mdstr, "SHA256=", sizeof(mdstr));
319 	mdstrp = mdstr + strlen(mdstr);
320 	for (i = 0; i < mdsize; i++) {
321 		PJDLOG_VERIFY(mdstrp + 3 <= mdstr + sizeof(mdstr));
322 		(void)sprintf(mdstrp, "%02hhX:", md[i]);
323 		mdstrp += 3;
324 	}
325 	/* Clear last colon. */
326 	mdstrp[-1] = '\0';
327 	if (strcasecmp(mdstr, fingerprint) != 0) {
328 		pjdlog_exitx(EX_NOPERM,
329 		    "Finger print doesn't match. Received \"%s\", expected \"%s\"",
330 		    mdstr, fingerprint);
331 	}
332 }
333 
334 static void
335 tls_exec_client(const char *user, int startfd, const char *srcaddr,
336     const char *dstaddr, const char *fingerprint, const char *defport,
337     int timeout, int debuglevel)
338 {
339 	struct proto_conn *tcp;
340 	char *saddr, *daddr;
341 	SSL_CTX *sslctx;
342 	SSL *ssl;
343 	long ret;
344 	int sockfd, tcpfd;
345 	uint8_t connected;
346 
347 	pjdlog_debug_set(debuglevel);
348 	pjdlog_prefix_set("[TLS sandbox] (client) ");
349 #ifdef HAVE_SETPROCTITLE
350 	setproctitle("[TLS sandbox] (client) ");
351 #endif
352 	proto_set("tcp:port", defport);
353 
354 	sockfd = startfd;
355 
356 	/* Change tls:// to tcp://. */
357 	if (srcaddr == NULL) {
358 		saddr = NULL;
359 	} else {
360 		saddr = strdup(srcaddr);
361 		if (saddr == NULL)
362 			pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory.");
363 		bcopy("tcp://", saddr, 6);
364 	}
365 	daddr = strdup(dstaddr);
366 	if (daddr == NULL)
367 		pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory.");
368 	bcopy("tcp://", daddr, 6);
369 
370 	/* Establish TCP connection. */
371 	if (proto_connect(saddr, daddr, timeout, &tcp) == -1)
372 		exit(EX_TEMPFAIL);
373 
374 	SSL_load_error_strings();
375 	SSL_library_init();
376 
377 	/*
378 	 * TODO: On FreeBSD we could move this below sandbox() once libc and
379 	 *       libcrypto use sysctl kern.arandom to obtain random data
380 	 *       instead of /dev/urandom and friends.
381 	 */
382 	sslctx = SSL_CTX_new(TLS_client_method());
383 	if (sslctx == NULL)
384 		pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed.");
385 
386 	if (sandbox(user, true, "proto_tls client: %s", dstaddr) != 0)
387 		pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS client.");
388 	pjdlog_debug(1, "Privileges successfully dropped.");
389 
390 	SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
391 
392 	/* Load CA certs. */
393 	/* TODO */
394 	//SSL_CTX_load_verify_locations(sslctx, cacerts_file, NULL);
395 
396 	ssl = SSL_new(sslctx);
397 	if (ssl == NULL)
398 		pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed.");
399 
400 	tcpfd = proto_descriptor(tcp);
401 
402 	block(tcpfd);
403 
404 	if (SSL_set_fd(ssl, tcpfd) != 1)
405 		pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed.");
406 
407 	ret = SSL_connect(ssl);
408 	ssl_check_error(ssl, (int)ret);
409 
410 	nonblock(sockfd);
411 	nonblock(tcpfd);
412 
413 	tls_certificate_verify(ssl, fingerprint);
414 
415 	/*
416 	 * The following byte is sent to make proto_connect_wait() work.
417 	 */
418 	connected = 1;
419 	for (;;) {
420 		switch (send(sockfd, &connected, sizeof(connected), 0)) {
421 		case -1:
422 			if (errno == EINTR || errno == ENOBUFS)
423 				continue;
424 			if (errno == EAGAIN) {
425 				(void)wait_for_fd(sockfd, -1);
426 				continue;
427 			}
428 			pjdlog_exit(EX_TEMPFAIL, "send() failed");
429 		case 0:
430 			pjdlog_debug(1, "Connection terminated.");
431 			exit(0);
432 		case 1:
433 			break;
434 		}
435 		break;
436 	}
437 
438 	tls_loop(sockfd, ssl);
439 }
440 
441 static void
442 tls_call_exec_client(struct proto_conn *sock, const char *srcaddr,
443     const char *dstaddr, int timeout)
444 {
445 	char *timeoutstr, *startfdstr, *debugstr;
446 	int startfd;
447 
448 	/* Declare that we are receiver. */
449 	proto_recv(sock, NULL, 0);
450 
451 	if (pjdlog_mode_get() == PJDLOG_MODE_STD)
452 		startfd = 3;
453 	else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */
454 		startfd = 0;
455 
456 	if (proto_descriptor(sock) != startfd) {
457 		/* Move socketpair descriptor to descriptor number startfd. */
458 		if (dup2(proto_descriptor(sock), startfd) == -1)
459 			pjdlog_exit(EX_OSERR, "dup2() failed");
460 		proto_close(sock);
461 	} else {
462 		/*
463 		 * The FD_CLOEXEC is cleared by dup2(2), so when we do not
464 		 * call it, we have to clear it by hand in case it is set.
465 		 */
466 		if (fcntl(startfd, F_SETFD, 0) == -1)
467 			pjdlog_exit(EX_OSERR, "fcntl() failed");
468 	}
469 
470 	closefrom(startfd + 1);
471 
472 	if (asprintf(&startfdstr, "%d", startfd) == -1)
473 		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
474 	if (timeout == -1)
475 		timeout = TLS_DEFAULT_TIMEOUT;
476 	if (asprintf(&timeoutstr, "%d", timeout) == -1)
477 		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
478 	if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1)
479 		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
480 
481 	execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls",
482 	    proto_get("user"), "client", startfdstr,
483 	    srcaddr == NULL ? "" : srcaddr, dstaddr,
484 	    proto_get("tls:fingerprint"), proto_get("tcp:port"), timeoutstr,
485 	    debugstr, NULL);
486 	pjdlog_exit(EX_SOFTWARE, "execl() failed");
487 }
488 
489 static int
490 tls_connect(const char *srcaddr, const char *dstaddr, int timeout, void **ctxp)
491 {
492 	struct tls_ctx *tlsctx;
493 	struct proto_conn *sock;
494 	pid_t pid;
495 	int error;
496 
497 	PJDLOG_ASSERT(srcaddr == NULL || srcaddr[0] != '\0');
498 	PJDLOG_ASSERT(dstaddr != NULL);
499 	PJDLOG_ASSERT(timeout >= -1);
500 	PJDLOG_ASSERT(ctxp != NULL);
501 
502 	if (strncmp(dstaddr, "tls://", 6) != 0)
503 		return (-1);
504 	if (srcaddr != NULL && strncmp(srcaddr, "tls://", 6) != 0)
505 		return (-1);
506 
507 	if (proto_connect(NULL, "socketpair://", -1, &sock) == -1)
508 		return (errno);
509 
510 #if 0
511 	/*
512 	 * We use rfork() with the following flags to disable SIGCHLD
513 	 * delivery upon the sandbox process exit.
514 	 */
515 	pid = rfork(RFFDG | RFPROC | RFTSIGZMB | RFTSIGFLAGS(0));
516 #else
517 	/*
518 	 * We don't use rfork() to be able to log information about sandbox
519 	 * process exiting.
520 	 */
521 	pid = fork();
522 #endif
523 	switch (pid) {
524 	case -1:
525 		/* Failure. */
526 		error = errno;
527 		proto_close(sock);
528 		return (error);
529 	case 0:
530 		/* Child. */
531 		pjdlog_prefix_set("[TLS sandbox] (client) ");
532 #ifdef HAVE_SETPROCTITLE
533 		setproctitle("[TLS sandbox] (client) ");
534 #endif
535 		tls_call_exec_client(sock, srcaddr, dstaddr, timeout);
536 		/* NOTREACHED */
537 	default:
538 		/* Parent. */
539 		tlsctx = calloc(1, sizeof(*tlsctx));
540 		if (tlsctx == NULL) {
541 			error = errno;
542 			proto_close(sock);
543 			(void)kill(pid, SIGKILL);
544 			return (error);
545 		}
546 		proto_send(sock, NULL, 0);
547 		tlsctx->tls_sock = sock;
548 		tlsctx->tls_tcp = NULL;
549 		tlsctx->tls_side = TLS_SIDE_CLIENT;
550 		tlsctx->tls_wait_called = false;
551 		tlsctx->tls_magic = TLS_CTX_MAGIC;
552 		if (timeout >= 0) {
553 			error = tls_connect_wait(tlsctx, timeout);
554 			if (error != 0) {
555 				(void)kill(pid, SIGKILL);
556 				tls_close(tlsctx);
557 				return (error);
558 			}
559 		}
560 		*ctxp = tlsctx;
561 		return (0);
562 	}
563 }
564 
565 static int
566 tls_connect_wait(void *ctx, int timeout)
567 {
568 	struct tls_ctx *tlsctx = ctx;
569 	int error, sockfd;
570 	uint8_t connected;
571 
572 	PJDLOG_ASSERT(tlsctx != NULL);
573 	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
574 	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT);
575 	PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
576 	PJDLOG_ASSERT(!tlsctx->tls_wait_called);
577 	PJDLOG_ASSERT(timeout >= 0);
578 
579 	sockfd = proto_descriptor(tlsctx->tls_sock);
580 	error = wait_for_fd(sockfd, timeout);
581 	if (error != 0)
582 		return (error);
583 
584 	for (;;) {
585 		switch (recv(sockfd, &connected, sizeof(connected),
586 		    MSG_WAITALL)) {
587 		case -1:
588 			if (errno == EINTR || errno == ENOBUFS)
589 				continue;
590 			error = errno;
591 			break;
592 		case 0:
593 			pjdlog_debug(1, "Connection terminated.");
594 			error = ENOTCONN;
595 			break;
596 		case 1:
597 			tlsctx->tls_wait_called = true;
598 			break;
599 		}
600 		break;
601 	}
602 
603 	return (error);
604 }
605 
606 static int
607 tls_server(const char *lstaddr, void **ctxp)
608 {
609 	struct proto_conn *tcp;
610 	struct tls_ctx *tlsctx;
611 	char *laddr;
612 	int error;
613 
614 	if (strncmp(lstaddr, "tls://", 6) != 0)
615 		return (-1);
616 
617 	tlsctx = malloc(sizeof(*tlsctx));
618 	if (tlsctx == NULL) {
619 		pjdlog_warning("Unable to allocate memory.");
620 		return (ENOMEM);
621 	}
622 
623 	laddr = strdup(lstaddr);
624 	if (laddr == NULL) {
625 		free(tlsctx);
626 		pjdlog_warning("Unable to allocate memory.");
627 		return (ENOMEM);
628 	}
629 	bcopy("tcp://", laddr, 6);
630 
631 	if (proto_server(laddr, &tcp) == -1) {
632 		error = errno;
633 		free(tlsctx);
634 		free(laddr);
635 		return (error);
636 	}
637 	free(laddr);
638 
639 	tlsctx->tls_sock = NULL;
640 	tlsctx->tls_tcp = tcp;
641 	tlsctx->tls_side = TLS_SIDE_SERVER_LISTEN;
642 	tlsctx->tls_wait_called = true;
643 	tlsctx->tls_magic = TLS_CTX_MAGIC;
644 	*ctxp = tlsctx;
645 
646 	return (0);
647 }
648 
649 static void
650 tls_exec_server(const char *user, int startfd, const char *privkey,
651     const char *cert, int debuglevel)
652 {
653 	SSL_CTX *sslctx;
654 	SSL *ssl;
655 	int sockfd, tcpfd, ret;
656 
657 	pjdlog_debug_set(debuglevel);
658 	pjdlog_prefix_set("[TLS sandbox] (server) ");
659 #ifdef HAVE_SETPROCTITLE
660 	setproctitle("[TLS sandbox] (server) ");
661 #endif
662 
663 	sockfd = startfd;
664 	tcpfd = startfd + 1;
665 
666 	SSL_load_error_strings();
667 	SSL_library_init();
668 
669 	sslctx = SSL_CTX_new(TLS_server_method());
670 	if (sslctx == NULL)
671 		pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed.");
672 
673 	SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
674 
675 	ssl = SSL_new(sslctx);
676 	if (ssl == NULL)
677 		pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed.");
678 
679 	if (SSL_use_RSAPrivateKey_file(ssl, privkey, SSL_FILETYPE_PEM) != 1) {
680 		ssl_log_errors();
681 		pjdlog_exitx(EX_CONFIG,
682 		    "SSL_use_RSAPrivateKey_file(%s) failed.", privkey);
683 	}
684 
685 	if (SSL_use_certificate_file(ssl, cert, SSL_FILETYPE_PEM) != 1) {
686 		ssl_log_errors();
687 		pjdlog_exitx(EX_CONFIG, "SSL_use_certificate_file(%s) failed.",
688 		    cert);
689 	}
690 
691 	if (sandbox(user, true, "proto_tls server") != 0)
692 		pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS server.");
693 	pjdlog_debug(1, "Privileges successfully dropped.");
694 
695 	nonblock(sockfd);
696 	nonblock(tcpfd);
697 
698 	if (SSL_set_fd(ssl, tcpfd) != 1)
699 		pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed.");
700 
701 	ret = SSL_accept(ssl);
702 	ssl_check_error(ssl, ret);
703 
704 	tls_loop(sockfd, ssl);
705 }
706 
707 static void
708 tls_call_exec_server(struct proto_conn *sock, struct proto_conn *tcp)
709 {
710 	int startfd, sockfd, tcpfd, safefd;
711 	char *startfdstr, *debugstr;
712 
713 	if (pjdlog_mode_get() == PJDLOG_MODE_STD)
714 		startfd = 3;
715 	else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */
716 		startfd = 0;
717 
718 	/* Declare that we are receiver. */
719 	proto_send(sock, NULL, 0);
720 
721 	sockfd = proto_descriptor(sock);
722 	tcpfd = proto_descriptor(tcp);
723 
724 	safefd = MAX(sockfd, tcpfd);
725 	safefd = MAX(safefd, startfd);
726 	safefd++;
727 
728 	/* Move sockfd and tcpfd to safe numbers first. */
729 	if (dup2(sockfd, safefd) == -1)
730 		pjdlog_exit(EX_OSERR, "dup2() failed");
731 	proto_close(sock);
732 	sockfd = safefd;
733 	if (dup2(tcpfd, safefd + 1) == -1)
734 		pjdlog_exit(EX_OSERR, "dup2() failed");
735 	proto_close(tcp);
736 	tcpfd = safefd + 1;
737 
738 	/* Move socketpair descriptor to descriptor number startfd. */
739 	if (dup2(sockfd, startfd) == -1)
740 		pjdlog_exit(EX_OSERR, "dup2() failed");
741 	(void)close(sockfd);
742 	/* Move tcp descriptor to descriptor number startfd + 1. */
743 	if (dup2(tcpfd, startfd + 1) == -1)
744 		pjdlog_exit(EX_OSERR, "dup2() failed");
745 	(void)close(tcpfd);
746 
747 	closefrom(startfd + 2);
748 
749 	/*
750 	 * Even if FD_CLOEXEC was set on descriptors before dup2(), it should
751 	 * have been cleared on dup2(), but better be safe than sorry.
752 	 */
753 	if (fcntl(startfd, F_SETFD, 0) == -1)
754 		pjdlog_exit(EX_OSERR, "fcntl() failed");
755 	if (fcntl(startfd + 1, F_SETFD, 0) == -1)
756 		pjdlog_exit(EX_OSERR, "fcntl() failed");
757 
758 	if (asprintf(&startfdstr, "%d", startfd) == -1)
759 		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
760 	if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1)
761 		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
762 
763 	execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls",
764 	    proto_get("user"), "server", startfdstr, proto_get("tls:keyfile"),
765 	    proto_get("tls:certfile"), debugstr, NULL);
766 	pjdlog_exit(EX_SOFTWARE, "execl() failed");
767 }
768 
769 static int
770 tls_accept(void *ctx, void **newctxp)
771 {
772 	struct tls_ctx *tlsctx = ctx;
773 	struct tls_ctx *newtlsctx;
774 	struct proto_conn *sock, *tcp;
775 	pid_t pid;
776 	int error;
777 
778 	PJDLOG_ASSERT(tlsctx != NULL);
779 	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
780 	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_SERVER_LISTEN);
781 
782 	if (proto_connect(NULL, "socketpair://", -1, &sock) == -1)
783 		return (errno);
784 
785 	/* Accept TCP connection. */
786 	if (proto_accept(tlsctx->tls_tcp, &tcp) == -1) {
787 		error = errno;
788 		proto_close(sock);
789 		return (error);
790 	}
791 
792 	pid = fork();
793 	switch (pid) {
794 	case -1:
795 		/* Failure. */
796 		error = errno;
797 		proto_close(sock);
798 		return (error);
799 	case 0:
800 		/* Child. */
801 		pjdlog_prefix_set("[TLS sandbox] (server) ");
802 #ifdef HAVE_SETPROCTITLE
803 		setproctitle("[TLS sandbox] (server) ");
804 #endif
805 		/* Close listen socket. */
806 		proto_close(tlsctx->tls_tcp);
807 		tls_call_exec_server(sock, tcp);
808 		/* NOTREACHED */
809 		PJDLOG_ABORT("Unreachable.");
810 	default:
811 		/* Parent. */
812 		newtlsctx = calloc(1, sizeof(*tlsctx));
813 		if (newtlsctx == NULL) {
814 			error = errno;
815 			proto_close(sock);
816 			proto_close(tcp);
817 			(void)kill(pid, SIGKILL);
818 			return (error);
819 		}
820 		proto_local_address(tcp, newtlsctx->tls_laddr,
821 		    sizeof(newtlsctx->tls_laddr));
822 		PJDLOG_ASSERT(strncmp(newtlsctx->tls_laddr, "tcp://", 6) == 0);
823 		bcopy("tls://", newtlsctx->tls_laddr, 6);
824 		*strrchr(newtlsctx->tls_laddr, ':') = '\0';
825 		proto_remote_address(tcp, newtlsctx->tls_raddr,
826 		    sizeof(newtlsctx->tls_raddr));
827 		PJDLOG_ASSERT(strncmp(newtlsctx->tls_raddr, "tcp://", 6) == 0);
828 		bcopy("tls://", newtlsctx->tls_raddr, 6);
829 		*strrchr(newtlsctx->tls_raddr, ':') = '\0';
830 		proto_close(tcp);
831 		proto_recv(sock, NULL, 0);
832 		newtlsctx->tls_sock = sock;
833 		newtlsctx->tls_tcp = NULL;
834 		newtlsctx->tls_wait_called = true;
835 		newtlsctx->tls_side = TLS_SIDE_SERVER_WORK;
836 		newtlsctx->tls_magic = TLS_CTX_MAGIC;
837 		*newctxp = newtlsctx;
838 		return (0);
839 	}
840 }
841 
842 static int
843 tls_wrap(int fd, bool client, void **ctxp)
844 {
845 	struct tls_ctx *tlsctx;
846 	struct proto_conn *sock;
847 	int error;
848 
849 	tlsctx = calloc(1, sizeof(*tlsctx));
850 	if (tlsctx == NULL)
851 		return (errno);
852 
853 	if (proto_wrap("socketpair", client, fd, &sock) == -1) {
854 		error = errno;
855 		free(tlsctx);
856 		return (error);
857 	}
858 
859 	tlsctx->tls_sock = sock;
860 	tlsctx->tls_tcp = NULL;
861 	tlsctx->tls_wait_called = (client ? false : true);
862 	tlsctx->tls_side = (client ? TLS_SIDE_CLIENT : TLS_SIDE_SERVER_WORK);
863 	tlsctx->tls_magic = TLS_CTX_MAGIC;
864 	*ctxp = tlsctx;
865 
866 	return (0);
867 }
868 
869 static int
870 tls_send(void *ctx, const unsigned char *data, size_t size, int fd)
871 {
872 	struct tls_ctx *tlsctx = ctx;
873 
874 	PJDLOG_ASSERT(tlsctx != NULL);
875 	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
876 	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT ||
877 	    tlsctx->tls_side == TLS_SIDE_SERVER_WORK);
878 	PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
879 	PJDLOG_ASSERT(tlsctx->tls_wait_called);
880 	PJDLOG_ASSERT(fd == -1);
881 
882 	if (proto_send(tlsctx->tls_sock, data, size) == -1)
883 		return (errno);
884 
885 	return (0);
886 }
887 
888 static int
889 tls_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
890 {
891 	struct tls_ctx *tlsctx = ctx;
892 
893 	PJDLOG_ASSERT(tlsctx != NULL);
894 	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
895 	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT ||
896 	    tlsctx->tls_side == TLS_SIDE_SERVER_WORK);
897 	PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
898 	PJDLOG_ASSERT(tlsctx->tls_wait_called);
899 	PJDLOG_ASSERT(fdp == NULL);
900 
901 	if (proto_recv(tlsctx->tls_sock, data, size) == -1)
902 		return (errno);
903 
904 	return (0);
905 }
906 
907 static int
908 tls_descriptor(const void *ctx)
909 {
910 	const struct tls_ctx *tlsctx = ctx;
911 
912 	PJDLOG_ASSERT(tlsctx != NULL);
913 	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
914 
915 	switch (tlsctx->tls_side) {
916 	case TLS_SIDE_CLIENT:
917 	case TLS_SIDE_SERVER_WORK:
918 		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
919 
920 		return (proto_descriptor(tlsctx->tls_sock));
921 	case TLS_SIDE_SERVER_LISTEN:
922 		PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
923 
924 		return (proto_descriptor(tlsctx->tls_tcp));
925 	default:
926 		PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
927 	}
928 }
929 
930 static bool
931 tcp_address_match(const void *ctx, const char *addr)
932 {
933 	const struct tls_ctx *tlsctx = ctx;
934 
935 	PJDLOG_ASSERT(tlsctx != NULL);
936 	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
937 
938 	return (strcmp(tlsctx->tls_raddr, addr) == 0);
939 }
940 
941 static void
942 tls_local_address(const void *ctx, char *addr, size_t size)
943 {
944 	const struct tls_ctx *tlsctx = ctx;
945 
946 	PJDLOG_ASSERT(tlsctx != NULL);
947 	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
948 	PJDLOG_ASSERT(tlsctx->tls_wait_called);
949 
950 	switch (tlsctx->tls_side) {
951 	case TLS_SIDE_CLIENT:
952 		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
953 
954 		PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size);
955 		break;
956 	case TLS_SIDE_SERVER_WORK:
957 		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
958 
959 		PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_laddr, size) < size);
960 		break;
961 	case TLS_SIDE_SERVER_LISTEN:
962 		PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
963 
964 		proto_local_address(tlsctx->tls_tcp, addr, size);
965 		PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0);
966 		/* Replace tcp:// prefix with tls:// */
967 		bcopy("tls://", addr, 6);
968 		break;
969 	default:
970 		PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
971 	}
972 }
973 
974 static void
975 tls_remote_address(const void *ctx, char *addr, size_t size)
976 {
977 	const struct tls_ctx *tlsctx = ctx;
978 
979 	PJDLOG_ASSERT(tlsctx != NULL);
980 	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
981 	PJDLOG_ASSERT(tlsctx->tls_wait_called);
982 
983 	switch (tlsctx->tls_side) {
984 	case TLS_SIDE_CLIENT:
985 		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
986 
987 		PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size);
988 		break;
989 	case TLS_SIDE_SERVER_WORK:
990 		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
991 
992 		PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_raddr, size) < size);
993 		break;
994 	case TLS_SIDE_SERVER_LISTEN:
995 		PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
996 
997 		proto_remote_address(tlsctx->tls_tcp, addr, size);
998 		PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0);
999 		/* Replace tcp:// prefix with tls:// */
1000 		bcopy("tls://", addr, 6);
1001 		break;
1002 	default:
1003 		PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
1004 	}
1005 }
1006 
1007 static void
1008 tls_close(void *ctx)
1009 {
1010 	struct tls_ctx *tlsctx = ctx;
1011 
1012 	PJDLOG_ASSERT(tlsctx != NULL);
1013 	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
1014 
1015 	if (tlsctx->tls_sock != NULL) {
1016 		proto_close(tlsctx->tls_sock);
1017 		tlsctx->tls_sock = NULL;
1018 	}
1019 	if (tlsctx->tls_tcp != NULL) {
1020 		proto_close(tlsctx->tls_tcp);
1021 		tlsctx->tls_tcp = NULL;
1022 	}
1023 	tlsctx->tls_side = 0;
1024 	tlsctx->tls_magic = 0;
1025 	free(tlsctx);
1026 }
1027 
1028 static int
1029 tls_exec(int argc, char *argv[])
1030 {
1031 
1032 	PJDLOG_ASSERT(argc > 3);
1033 	PJDLOG_ASSERT(strcmp(argv[0], "tls") == 0);
1034 
1035 	pjdlog_init(atoi(argv[3]) == 0 ? PJDLOG_MODE_SYSLOG : PJDLOG_MODE_STD);
1036 
1037 	if (strcmp(argv[2], "client") == 0) {
1038 		if (argc != 10)
1039 			return (EINVAL);
1040 		tls_exec_client(argv[1], atoi(argv[3]),
1041 		    argv[4][0] == '\0' ? NULL : argv[4], argv[5], argv[6],
1042 		    argv[7], atoi(argv[8]), atoi(argv[9]));
1043 	} else if (strcmp(argv[2], "server") == 0) {
1044 		if (argc != 7)
1045 			return (EINVAL);
1046 		tls_exec_server(argv[1], atoi(argv[3]), argv[4], argv[5],
1047 		    atoi(argv[6]));
1048 	}
1049 	return (EINVAL);
1050 }
1051 
1052 static struct proto tls_proto = {
1053 	.prt_name = "tls",
1054 	.prt_connect = tls_connect,
1055 	.prt_connect_wait = tls_connect_wait,
1056 	.prt_server = tls_server,
1057 	.prt_accept = tls_accept,
1058 	.prt_wrap = tls_wrap,
1059 	.prt_send = tls_send,
1060 	.prt_recv = tls_recv,
1061 	.prt_descriptor = tls_descriptor,
1062 	.prt_address_match = tcp_address_match,
1063 	.prt_local_address = tls_local_address,
1064 	.prt_remote_address = tls_remote_address,
1065 	.prt_close = tls_close,
1066 	.prt_exec = tls_exec
1067 };
1068 
1069 static __constructor void
1070 tls_ctor(void)
1071 {
1072 
1073 	proto_register(&tls_proto, false);
1074 }
1075