xref: /freebsd/contrib/openbsm/bin/auditdistd/proto_tls.c (revision ce3adf4362fcca6a43e500b2531f0038adbfbd21)
1 /*-
2  * Copyright (c) 2011 The FreeBSD Foundation
3  * All rights reserved.
4  *
5  * This software was developed by Pawel Jakub Dawidek under sponsorship from
6  * the FreeBSD Foundation.
7  *
8  * Redistribution and use in source and binary forms, with or without
9  * modification, are permitted provided that the following conditions
10  * are met:
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in the
15  *    documentation and/or other materials provided with the distribution.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27  * SUCH DAMAGE.
28  *
29  * $P4: //depot/projects/trustedbsd/openbsm/bin/auditdistd/proto_tls.c#2 $
30  */
31 
32 #include <config/config.h>
33 
34 #include <sys/param.h>	/* MAXHOSTNAMELEN */
35 #include <sys/socket.h>
36 
37 #include <arpa/inet.h>
38 
39 #include <netinet/in.h>
40 #include <netinet/tcp.h>
41 
42 #include <errno.h>
43 #include <fcntl.h>
44 #include <netdb.h>
45 #include <signal.h>
46 #include <stdbool.h>
47 #include <stdint.h>
48 #include <stdio.h>
49 #include <string.h>
50 #include <unistd.h>
51 
52 #include <openssl/err.h>
53 #include <openssl/ssl.h>
54 
55 #include <compat/compat.h>
56 #ifndef HAVE_CLOSEFROM
57 #include <compat/closefrom.h>
58 #endif
59 #ifndef HAVE_STRLCPY
60 #include <compat/strlcpy.h>
61 #endif
62 
63 #include "pjdlog.h"
64 #include "proto_impl.h"
65 #include "sandbox.h"
66 #include "subr.h"
67 
/* Value kept in tls_magic while a context is valid; checked by assertions. */
#define	TLS_CTX_MAGIC	0x715c7
/*
 * State of one "tls" protocol endpoint as seen by the calling process.
 * The TLS work itself is done by a separate sandbox process; this
 * structure only tracks the descriptors used to reach it.
 */
struct tls_ctx {
	/* TLS_CTX_MAGIC for a live context; tls_close() resets it to 0. */
	int		tls_magic;
	/* "socketpair://" connection to the sandbox process (client/work sides). */
	struct proto_conn *tls_sock;
	/* Listening TCP connection; set only for TLS_SIDE_SERVER_LISTEN. */
	struct proto_conn *tls_tcp;
	/* Local and remote "tls://" addresses, filled in by tls_accept(). */
	char		tls_laddr[256];
	char		tls_raddr[256];
	/* One of the TLS_SIDE_* values below. */
	int		tls_side;
#define	TLS_SIDE_CLIENT		0
#define	TLS_SIDE_SERVER_LISTEN	1
#define	TLS_SIDE_SERVER_WORK	2
	/*
	 * True once tls_connect_wait() has seen the "connected" byte (client
	 * side); server-side contexts are created with it already true.
	 * tls_send()/tls_recv() assert on it.
	 */
	bool		tls_wait_called;
};
81 
82 #define	TLS_DEFAULT_TIMEOUT	30
83 
84 static int tls_connect_wait(void *ctx, int timeout);
85 static void tls_close(void *ctx);
86 
/*
 * Put descriptor fd into blocking mode by clearing O_NONBLOCK.
 * fcntl(2) failures are reported through pjdlog_exit(EX_TEMPFAIL, ...).
 */
static void
block(int fd)
{
	int flags;

	if ((flags = fcntl(fd, F_GETFL)) == -1)
		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed");
	if (fcntl(fd, F_SETFL, flags & ~O_NONBLOCK) == -1)
		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed");
}
99 
/*
 * Put descriptor fd into non-blocking mode by setting O_NONBLOCK.
 * fcntl(2) failures are reported through pjdlog_exit(EX_TEMPFAIL, ...).
 */
static void
nonblock(int fd)
{
	int flags;

	if ((flags = fcntl(fd, F_GETFL)) == -1)
		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed");
	if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1)
		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed");
}
112 
113 static int
114 wait_for_fd(int fd, int timeout)
115 {
116 	struct timeval tv;
117 	fd_set fdset;
118 	int error, ret;
119 
120 	error = 0;
121 
122 	for (;;) {
123 		FD_ZERO(&fdset);
124 		FD_SET(fd, &fdset);
125 
126 		tv.tv_sec = timeout;
127 		tv.tv_usec = 0;
128 
129 		ret = select(fd + 1, NULL, &fdset, NULL,
130 		    timeout == -1 ? NULL : &tv);
131 		if (ret == 0) {
132 			error = ETIMEDOUT;
133 			break;
134 		} else if (ret == -1) {
135 			if (errno == EINTR)
136 				continue;
137 			error = errno;
138 			break;
139 		}
140 		PJDLOG_ASSERT(ret > 0);
141 		PJDLOG_ASSERT(FD_ISSET(fd, &fdset));
142 		break;
143 	}
144 
145 	return (error);
146 }
147 
148 static void
149 ssl_log_errors(void)
150 {
151 	unsigned long error;
152 
153 	while ((error = ERR_get_error()) != 0)
154 		pjdlog_error("SSL error: %s", ERR_error_string(error, NULL));
155 }
156 
157 static int
158 ssl_check_error(SSL *ssl, int ret)
159 {
160 	int error;
161 
162 	error = SSL_get_error(ssl, ret);
163 
164 	switch (error) {
165 	case SSL_ERROR_NONE:
166 		return (0);
167 	case SSL_ERROR_WANT_READ:
168 		pjdlog_debug(2, "SSL_ERROR_WANT_READ");
169 		return (-1);
170 	case SSL_ERROR_WANT_WRITE:
171 		pjdlog_debug(2, "SSL_ERROR_WANT_WRITE");
172 		return (-1);
173 	case SSL_ERROR_ZERO_RETURN:
174 		pjdlog_exitx(EX_OK, "Connection closed.");
175 	case SSL_ERROR_SYSCALL:
176 		ssl_log_errors();
177 		pjdlog_exitx(EX_TEMPFAIL, "SSL I/O error.");
178 	case SSL_ERROR_SSL:
179 		ssl_log_errors();
180 		pjdlog_exitx(EX_TEMPFAIL, "SSL protocol error.");
181 	default:
182 		ssl_log_errors();
183 		pjdlog_exitx(EX_TEMPFAIL, "Unknown SSL error (%d).", error);
184 	}
185 }
186 
/*
 * Forward data from the plain descriptor recvfd to the SSL connection
 * sendssl.
 *
 * Reads from recvfd until recv(2) reports EAGAIN and writes every chunk to
 * sendssl with SSL_write().  The process exits when the peer on recvfd
 * closes the connection or when recv(2) fails with any other error.
 */
static void
tcp_recv_ssl_send(int recvfd, SSL *sendssl)
{
	static unsigned char buf[65536];
	ssize_t tcpdone;
	int sendfd, ssldone;

	sendfd = SSL_get_fd(sendssl);
	PJDLOG_ASSERT(sendfd >= 0);
	pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd);
	for (;;) {
		tcpdone = recv(recvfd, buf, sizeof(buf), 0);
		pjdlog_debug(2, "%s: recv() returned %zd", __func__, tcpdone);
		if (tcpdone == 0) {
			pjdlog_debug(1, "Connection terminated.");
			exit(0);
		} else if (tcpdone == -1) {
			if (errno == EINTR)
				continue;
			else if (errno == EAGAIN)
				break;	/* Nothing more to read for now. */
			pjdlog_exit(EX_TEMPFAIL, "recv() failed");
		}
		/* Retry SSL_write() of the same chunk until it is accepted. */
		for (;;) {
			ssldone = SSL_write(sendssl, buf, (int)tcpdone);
			pjdlog_debug(2, "%s: send() returned %d", __func__,
			    ssldone);
			if (ssl_check_error(sendssl, ssldone) == -1) {
				/*
				 * WANT_READ/WANT_WRITE: wait for the SSL
				 * descriptor to become writable, then retry.
				 */
				(void)wait_for_fd(sendfd, -1);
				continue;
			}
			PJDLOG_ASSERT(ssldone == tcpdone);
			break;
		}
	}
	pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd);
}
224 
/*
 * Forward data from the SSL connection recvssl to the plain descriptor
 * sendfd.
 *
 * Calls SSL_read() until it asks to be retried (ssl_check_error() returns
 * -1) and sends every chunk in full with send(2).  The process exits when
 * send(2) returns 0 or fails with an error that is not retried here.
 */
static void
ssl_recv_tcp_send(SSL *recvssl, int sendfd)
{
	static unsigned char buf[65536];
	unsigned char *ptr;
	ssize_t tcpdone;
	size_t todo;
	int recvfd, ssldone;

	recvfd = SSL_get_fd(recvssl);
	PJDLOG_ASSERT(recvfd >= 0);
	pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd);
	for (;;) {
		ssldone = SSL_read(recvssl, buf, sizeof(buf));
		pjdlog_debug(2, "%s: SSL_read() returned %d", __func__,
		    ssldone);
		if (ssl_check_error(recvssl, ssldone) == -1)
			break;	/* No more decrypted data available for now. */
		todo = (size_t)ssldone;
		ptr = buf;
		/* send(2) may be partial; loop until the chunk is written. */
		do {
			tcpdone = send(sendfd, ptr, todo, MSG_NOSIGNAL);
			pjdlog_debug(2, "%s: send() returned %zd", __func__,
			    tcpdone);
			if (tcpdone == 0) {
				pjdlog_debug(1, "Connection terminated.");
				exit(0);
			} else if (tcpdone == -1) {
				if (errno == EINTR || errno == ENOBUFS)
					continue;
				if (errno == EAGAIN) {
					/* Wait until sendfd is writable. */
					(void)wait_for_fd(sendfd, -1);
					continue;
				}
				pjdlog_exit(EX_TEMPFAIL, "send() failed");
			}
			todo -= tcpdone;
			ptr += tcpdone;
		} while (todo > 0);
	}
	pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd);
}
267 
268 static void
269 tls_loop(int sockfd, SSL *tcpssl)
270 {
271 	fd_set fds;
272 	int maxfd, tcpfd;
273 
274 	tcpfd = SSL_get_fd(tcpssl);
275 	PJDLOG_ASSERT(tcpfd >= 0);
276 
277 	for (;;) {
278 		FD_ZERO(&fds);
279 		FD_SET(sockfd, &fds);
280 		FD_SET(tcpfd, &fds);
281 		maxfd = MAX(sockfd, tcpfd);
282 
283 		PJDLOG_ASSERT(maxfd + 1 <= (int)FD_SETSIZE);
284 		if (select(maxfd + 1, &fds, NULL, NULL, NULL) == -1) {
285 			if (errno == EINTR)
286 				continue;
287 			pjdlog_exit(EX_TEMPFAIL, "select() failed");
288 		}
289 		if (FD_ISSET(sockfd, &fds))
290 			tcp_recv_ssl_send(sockfd, tcpssl);
291 		if (FD_ISSET(tcpfd, &fds))
292 			ssl_recv_tcp_send(tcpssl, sockfd);
293 	}
294 }
295 
296 static void
297 tls_certificate_verify(SSL *ssl, const char *fingerprint)
298 {
299 	unsigned char md[EVP_MAX_MD_SIZE];
300 	char mdstr[sizeof("SHA256=") - 1 + EVP_MAX_MD_SIZE * 3];
301 	char *mdstrp;
302 	unsigned int i, mdsize;
303 	X509 *cert;
304 
305 	if (fingerprint[0] == '\0') {
306 		pjdlog_debug(1, "No fingerprint verification requested.");
307 		return;
308 	}
309 
310 	cert = SSL_get_peer_certificate(ssl);
311 	if (cert == NULL)
312 		pjdlog_exitx(EX_TEMPFAIL, "No peer certificate received.");
313 
314 	if (X509_digest(cert, EVP_sha256(), md, &mdsize) != 1)
315 		pjdlog_exitx(EX_TEMPFAIL, "X509_digest() failed.");
316 	PJDLOG_ASSERT(mdsize <= EVP_MAX_MD_SIZE);
317 
318 	X509_free(cert);
319 
320 	(void)strlcpy(mdstr, "SHA256=", sizeof(mdstr));
321 	mdstrp = mdstr + strlen(mdstr);
322 	for (i = 0; i < mdsize; i++) {
323 		PJDLOG_VERIFY(mdstrp + 3 <= mdstr + sizeof(mdstr));
324 		(void)sprintf(mdstrp, "%02hhX:", md[i]);
325 		mdstrp += 3;
326 	}
327 	/* Clear last colon. */
328 	mdstrp[-1] = '\0';
329 	if (strcasecmp(mdstr, fingerprint) != 0) {
330 		pjdlog_exitx(EX_NOPERM,
331 		    "Finger print doesn't match. Received \"%s\", expected \"%s\"",
332 		    mdstr, fingerprint);
333 	}
334 }
335 
/*
 * Body of the client-side TLS sandbox process.
 *
 * user        - passed to sandbox() when dropping privileges
 * startfd     - descriptor connected to the parent process
 * srcaddr     - optional "tls://" source address (NULL if none)
 * dstaddr     - "tls://" destination address
 * fingerprint - expected certificate fingerprint ("" disables the check)
 * defport     - value stored as "tcp:port" via proto_set()
 * timeout     - timeout passed to proto_connect() for the TCP connection
 * debuglevel  - debug level handed to pjdlog_debug_set()
 *
 * Connects over TCP, performs the TLS handshake, verifies the peer
 * fingerprint, notifies the parent with a single byte and then enters
 * tls_loop().
 */
static void
tls_exec_client(const char *user, int startfd, const char *srcaddr,
    const char *dstaddr, const char *fingerprint, const char *defport,
    int timeout, int debuglevel)
{
	struct proto_conn *tcp;
	char *saddr, *daddr;
	SSL_CTX *sslctx;
	SSL *ssl;
	long ret;
	int sockfd, tcpfd;
	uint8_t connected;

	pjdlog_debug_set(debuglevel);
	pjdlog_prefix_set("[TLS sandbox] (client) ");
#ifdef HAVE_SETPROCTITLE
	setproctitle("[TLS sandbox] (client) ");
#endif
	proto_set("tcp:port", defport);

	sockfd = startfd;

	/* Change tls:// to tcp://. */
	if (srcaddr == NULL) {
		saddr = NULL;
	} else {
		saddr = strdup(srcaddr);
		if (saddr == NULL)
			pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory.");
		bcopy("tcp://", saddr, 6);
	}
	daddr = strdup(dstaddr);
	if (daddr == NULL)
		pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory.");
	bcopy("tcp://", daddr, 6);

	/* Establish TCP connection. */
	if (proto_connect(saddr, daddr, timeout, &tcp) == -1)
		exit(EX_TEMPFAIL);

	SSL_load_error_strings();
	SSL_library_init();

	/*
	 * TODO: On FreeBSD we could move this below sandbox() once libc and
	 *       libcrypto use sysctl kern.arandom to obtain random data
	 *       instead of /dev/urandom and friends.
	 */
	/*
	 * NOTE(review): TLSv1_client_method() appears to restrict the
	 * connection to TLS 1.0, which would make the SSL_OP_NO_SSLv2 |
	 * SSL_OP_NO_SSLv3 options below redundant - confirm whether a
	 * version-flexible method was intended.
	 */
	sslctx = SSL_CTX_new(TLSv1_client_method());
	if (sslctx == NULL)
		pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed.");

	if (sandbox(user, true, "proto_tls client: %s", dstaddr) != 0)
		pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS client.");
	pjdlog_debug(1, "Privileges successfully dropped.");

	SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);

	/* Load CA certs. */
	/* TODO */
	//SSL_CTX_load_verify_locations(sslctx, cacerts_file, NULL);

	ssl = SSL_new(sslctx);
	if (ssl == NULL)
		pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed.");

	tcpfd = proto_descriptor(tcp);

	/*
	 * The handshake is done in blocking mode.
	 * NOTE(review): 'timeout' is not applied to SSL_connect() here.
	 */
	block(tcpfd);

	if (SSL_set_fd(ssl, tcpfd) != 1)
		pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed.");

	ret = SSL_connect(ssl);
	ssl_check_error(ssl, (int)ret);

	/* tls_loop() and its helpers rely on non-blocking descriptors. */
	nonblock(sockfd);
	nonblock(tcpfd);

	tls_certificate_verify(ssl, fingerprint);

	/*
	 * The following byte is sent to make proto_connect_wait() work.
	 */
	connected = 1;
	for (;;) {
		switch (send(sockfd, &connected, sizeof(connected), 0)) {
		case -1:
			if (errno == EINTR || errno == ENOBUFS)
				continue;
			if (errno == EAGAIN) {
				(void)wait_for_fd(sockfd, -1);
				continue;
			}
			pjdlog_exit(EX_TEMPFAIL, "send() failed");
		case 0:
			pjdlog_debug(1, "Connection terminated.");
			exit(0);
		case 1:
			break;
		}
		break;
	}

	tls_loop(sockfd, ssl);
}
442 
/*
 * Runs in the forked child of tls_connect(): arrange descriptors and
 * re-execute the program as the client-side TLS sandbox.
 *
 * The socketpair end 'sock' is moved to descriptor 'startfd' (3 in
 * PJDLOG_MODE_STD, 0 otherwise), all higher descriptors are closed and the
 * binary named by proto_get("execpath") is executed with the arguments
 * that tls_exec() parses.  A timeout of -1 is replaced by
 * TLS_DEFAULT_TIMEOUT.  The call ends in execl() or pjdlog_exit().
 */
static void
tls_call_exec_client(struct proto_conn *sock, const char *srcaddr,
    const char *dstaddr, int timeout)
{
	char *timeoutstr, *startfdstr, *debugstr;
	int startfd;

	/* Declare that we are receiver. */
	proto_recv(sock, NULL, 0);

	if (pjdlog_mode_get() == PJDLOG_MODE_STD)
		startfd = 3;
	else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */
		startfd = 0;

	if (proto_descriptor(sock) != startfd) {
		/* Move socketpair descriptor to descriptor number startfd. */
		if (dup2(proto_descriptor(sock), startfd) == -1)
			pjdlog_exit(EX_OSERR, "dup2() failed");
		proto_close(sock);
	} else {
		/*
		 * FD_CLOEXEC is cleared by dup2(2), so when we do not
		 * call it, we have to clear the flag by hand in case it is
		 * set.
		 */
		if (fcntl(startfd, F_SETFD, 0) == -1)
			pjdlog_exit(EX_OSERR, "fcntl() failed");
	}

	/* Leave only descriptors up to and including startfd open. */
	closefrom(startfd + 1);

	if (asprintf(&startfdstr, "%d", startfd) == -1)
		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
	if (timeout == -1)
		timeout = TLS_DEFAULT_TIMEOUT;
	if (asprintf(&timeoutstr, "%d", timeout) == -1)
		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
	if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1)
		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");

	/* An absent source address is passed as an empty string. */
	execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls",
	    proto_get("user"), "client", startfdstr,
	    srcaddr == NULL ? "" : srcaddr, dstaddr,
	    proto_get("tls:fingerprint"), proto_get("tcp:port"), timeoutstr,
	    debugstr, NULL);
	pjdlog_exit(EX_SOFTWARE, "execl() failed");
}
490 
/*
 * Start a client-side TLS connection to dstaddr (optionally bound to
 * srcaddr).
 *
 * A socketpair is created and a child process is forked; the child turns
 * into the TLS sandbox via tls_call_exec_client().  The parent allocates
 * the context returned through *ctxp and, when timeout >= 0, waits for the
 * sandbox to report an established connection with tls_connect_wait().
 *
 * Returns 0 on success, -1 when an address does not start with "tls://",
 * or an errno value on failure.
 */
static int
tls_connect(const char *srcaddr, const char *dstaddr, int timeout, void **ctxp)
{
	struct tls_ctx *tlsctx;
	struct proto_conn *sock;
	pid_t pid;
	int error;

	PJDLOG_ASSERT(srcaddr == NULL || srcaddr[0] != '\0');
	PJDLOG_ASSERT(dstaddr != NULL);
	PJDLOG_ASSERT(timeout >= -1);
	PJDLOG_ASSERT(ctxp != NULL);

	if (strncmp(dstaddr, "tls://", 6) != 0)
		return (-1);
	if (srcaddr != NULL && strncmp(srcaddr, "tls://", 6) != 0)
		return (-1);

	if (proto_connect(NULL, "socketpair://", -1, &sock) == -1)
		return (errno);

#if 0
	/*
	 * We use rfork() with the following flags to disable SIGCHLD
	 * delivery upon the sandbox process exit.
	 */
	pid = rfork(RFFDG | RFPROC | RFTSIGZMB | RFTSIGFLAGS(0));
#else
	/*
	 * We don't use rfork() to be able to log information about sandbox
	 * process exiting.
	 */
	pid = fork();
#endif
	switch (pid) {
	case -1:
		/* Failure. */
		error = errno;
		proto_close(sock);
		return (error);
	case 0:
		/* Child. */
		pjdlog_prefix_set("[TLS sandbox] (client) ");
#ifdef HAVE_SETPROCTITLE
		setproctitle("[TLS sandbox] (client) ");
#endif
		tls_call_exec_client(sock, srcaddr, dstaddr, timeout);
		/* NOTREACHED */
	default:
		/* Parent. */
		tlsctx = calloc(1, sizeof(*tlsctx));
		if (tlsctx == NULL) {
			error = errno;
			proto_close(sock);
			(void)kill(pid, SIGKILL);
			return (error);
		}
		/* The child calls proto_recv(); this is the matching call. */
		proto_send(sock, NULL, 0);
		tlsctx->tls_sock = sock;
		tlsctx->tls_tcp = NULL;
		tlsctx->tls_side = TLS_SIDE_CLIENT;
		tlsctx->tls_wait_called = false;
		tlsctx->tls_magic = TLS_CTX_MAGIC;
		if (timeout >= 0) {
			error = tls_connect_wait(tlsctx, timeout);
			if (error != 0) {
				/* Give up: stop the sandbox and free the context. */
				(void)kill(pid, SIGKILL);
				tls_close(tlsctx);
				return (error);
			}
		}
		*ctxp = tlsctx;
		return (0);
	}
}
566 
567 static int
568 tls_connect_wait(void *ctx, int timeout)
569 {
570 	struct tls_ctx *tlsctx = ctx;
571 	int error, sockfd;
572 	uint8_t connected;
573 
574 	PJDLOG_ASSERT(tlsctx != NULL);
575 	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
576 	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT);
577 	PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
578 	PJDLOG_ASSERT(!tlsctx->tls_wait_called);
579 	PJDLOG_ASSERT(timeout >= 0);
580 
581 	sockfd = proto_descriptor(tlsctx->tls_sock);
582 	error = wait_for_fd(sockfd, timeout);
583 	if (error != 0)
584 		return (error);
585 
586 	for (;;) {
587 		switch (recv(sockfd, &connected, sizeof(connected),
588 		    MSG_WAITALL)) {
589 		case -1:
590 			if (errno == EINTR || errno == ENOBUFS)
591 				continue;
592 			error = errno;
593 			break;
594 		case 0:
595 			pjdlog_debug(1, "Connection terminated.");
596 			error = ENOTCONN;
597 			break;
598 		case 1:
599 			tlsctx->tls_wait_called = true;
600 			break;
601 		}
602 		break;
603 	}
604 
605 	return (error);
606 }
607 
608 static int
609 tls_server(const char *lstaddr, void **ctxp)
610 {
611 	struct proto_conn *tcp;
612 	struct tls_ctx *tlsctx;
613 	char *laddr;
614 	int error;
615 
616 	if (strncmp(lstaddr, "tls://", 6) != 0)
617 		return (-1);
618 
619 	tlsctx = malloc(sizeof(*tlsctx));
620 	if (tlsctx == NULL) {
621 		pjdlog_warning("Unable to allocate memory.");
622 		return (ENOMEM);
623 	}
624 
625 	laddr = strdup(lstaddr);
626 	if (laddr == NULL) {
627 		free(tlsctx);
628 		pjdlog_warning("Unable to allocate memory.");
629 		return (ENOMEM);
630 	}
631 	bcopy("tcp://", laddr, 6);
632 
633 	if (proto_server(laddr, &tcp) == -1) {
634 		error = errno;
635 		free(tlsctx);
636 		free(laddr);
637 		return (error);
638 	}
639 	free(laddr);
640 
641 	tlsctx->tls_sock = NULL;
642 	tlsctx->tls_tcp = tcp;
643 	tlsctx->tls_side = TLS_SIDE_SERVER_LISTEN;
644 	tlsctx->tls_wait_called = true;
645 	tlsctx->tls_magic = TLS_CTX_MAGIC;
646 	*ctxp = tlsctx;
647 
648 	return (0);
649 }
650 
/*
 * Body of the server-side TLS sandbox process.
 *
 * user       - passed to sandbox() when dropping privileges
 * startfd    - descriptor connected to the parent process; the accepted
 *              TCP connection is expected at startfd + 1
 * privkey    - path of the PEM file with the RSA private key
 * cert       - path of the PEM file with the certificate
 * debuglevel - debug level handed to pjdlog_debug_set()
 *
 * Loads the key and certificate, drops privileges, starts the TLS
 * handshake and enters tls_loop().
 */
static void
tls_exec_server(const char *user, int startfd, const char *privkey,
    const char *cert, int debuglevel)
{
	SSL_CTX *sslctx;
	SSL *ssl;
	int sockfd, tcpfd, ret;

	pjdlog_debug_set(debuglevel);
	pjdlog_prefix_set("[TLS sandbox] (server) ");
#ifdef HAVE_SETPROCTITLE
	setproctitle("[TLS sandbox] (server) ");
#endif

	sockfd = startfd;
	tcpfd = startfd + 1;

	SSL_load_error_strings();
	SSL_library_init();

	/*
	 * NOTE(review): TLSv1_server_method() appears to restrict the
	 * connection to TLS 1.0, which would make the SSL_OP_NO_SSLv2 |
	 * SSL_OP_NO_SSLv3 options below redundant - confirm whether a
	 * version-flexible method was intended.
	 */
	sslctx = SSL_CTX_new(TLSv1_server_method());
	if (sslctx == NULL)
		pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed.");

	SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);

	ssl = SSL_new(sslctx);
	if (ssl == NULL)
		pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed.");

	/* Key and certificate are loaded before privileges are dropped. */
	if (SSL_use_RSAPrivateKey_file(ssl, privkey, SSL_FILETYPE_PEM) != 1) {
		ssl_log_errors();
		pjdlog_exitx(EX_CONFIG,
		    "SSL_use_RSAPrivateKey_file(%s) failed.", privkey);
	}

	if (SSL_use_certificate_file(ssl, cert, SSL_FILETYPE_PEM) != 1) {
		ssl_log_errors();
		pjdlog_exitx(EX_CONFIG, "SSL_use_certificate_file(%s) failed.",
		    cert);
	}

	if (sandbox(user, true, "proto_tls server") != 0)
		pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS server.");
	pjdlog_debug(1, "Privileges successfully dropped.");

	nonblock(sockfd);
	nonblock(tcpfd);

	if (SSL_set_fd(ssl, tcpfd) != 1)
		pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed.");

	/*
	 * The descriptor is non-blocking here, so ssl_check_error() may
	 * return -1 (WANT_READ/WANT_WRITE); that result is ignored.
	 * NOTE(review): presumably the handshake is then completed by the
	 * SSL_read()/SSL_write() calls made from tls_loop() - confirm.
	 */
	ret = SSL_accept(ssl);
	ssl_check_error(ssl, ret);

	tls_loop(sockfd, ssl);
}
708 
/*
 * Runs in the forked child of tls_accept(): arrange descriptors and
 * re-execute the program as the server-side TLS sandbox.
 *
 * The socketpair end 'sock' ends up at descriptor 'startfd' and the
 * accepted TCP connection 'tcp' at 'startfd + 1' (startfd is 3 in
 * PJDLOG_MODE_STD, 0 otherwise).  All higher descriptors are closed and
 * the binary named by proto_get("execpath") is executed with the arguments
 * that tls_exec() parses.  The call ends in execl() or pjdlog_exit().
 */
static void
tls_call_exec_server(struct proto_conn *sock, struct proto_conn *tcp)
{
	int startfd, sockfd, tcpfd, safefd;
	char *startfdstr, *debugstr;

	if (pjdlog_mode_get() == PJDLOG_MODE_STD)
		startfd = 3;
	else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */
		startfd = 0;

	/*
	 * Declare our side of the socketpair.  The parent in tls_accept()
	 * calls proto_recv(), so this end uses proto_send().
	 * NOTE(review): the previous comment said "receiver" here, which
	 * does not match the call - presumably this declares the sender.
	 */
	proto_send(sock, NULL, 0);

	sockfd = proto_descriptor(sock);
	tcpfd = proto_descriptor(tcp);

	/* Pick a descriptor number above everything we are about to touch. */
	safefd = MAX(sockfd, tcpfd);
	safefd = MAX(safefd, startfd);
	safefd++;

	/* Move sockfd and tcpfd to safe numbers first. */
	if (dup2(sockfd, safefd) == -1)
		pjdlog_exit(EX_OSERR, "dup2() failed");
	proto_close(sock);
	sockfd = safefd;
	if (dup2(tcpfd, safefd + 1) == -1)
		pjdlog_exit(EX_OSERR, "dup2() failed");
	proto_close(tcp);
	tcpfd = safefd + 1;

	/* Move socketpair descriptor to descriptor number startfd. */
	if (dup2(sockfd, startfd) == -1)
		pjdlog_exit(EX_OSERR, "dup2() failed");
	(void)close(sockfd);
	/* Move tcp descriptor to descriptor number startfd + 1. */
	if (dup2(tcpfd, startfd + 1) == -1)
		pjdlog_exit(EX_OSERR, "dup2() failed");
	(void)close(tcpfd);

	closefrom(startfd + 2);

	/*
	 * Even if FD_CLOEXEC was set on descriptors before dup2(), it should
	 * have been cleared on dup2(), but better be safe than sorry.
	 */
	if (fcntl(startfd, F_SETFD, 0) == -1)
		pjdlog_exit(EX_OSERR, "fcntl() failed");
	if (fcntl(startfd + 1, F_SETFD, 0) == -1)
		pjdlog_exit(EX_OSERR, "fcntl() failed");

	if (asprintf(&startfdstr, "%d", startfd) == -1)
		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
	if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1)
		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");

	execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls",
	    proto_get("user"), "server", startfdstr, proto_get("tls:keyfile"),
	    proto_get("tls:certfile"), debugstr, NULL);
	pjdlog_exit(EX_SOFTWARE, "execl() failed");
}
770 
771 static int
772 tls_accept(void *ctx, void **newctxp)
773 {
774 	struct tls_ctx *tlsctx = ctx;
775 	struct tls_ctx *newtlsctx;
776 	struct proto_conn *sock, *tcp;
777 	pid_t pid;
778 	int error;
779 
780 	PJDLOG_ASSERT(tlsctx != NULL);
781 	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
782 	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_SERVER_LISTEN);
783 
784 	if (proto_connect(NULL, "socketpair://", -1, &sock) == -1)
785 		return (errno);
786 
787 	/* Accept TCP connection. */
788 	if (proto_accept(tlsctx->tls_tcp, &tcp) == -1) {
789 		error = errno;
790 		proto_close(sock);
791 		return (error);
792 	}
793 
794 	pid = fork();
795 	switch (pid) {
796 	case -1:
797 		/* Failure. */
798 		error = errno;
799 		proto_close(sock);
800 		return (error);
801 	case 0:
802 		/* Child. */
803 		pjdlog_prefix_set("[TLS sandbox] (server) ");
804 #ifdef HAVE_SETPROCTITLE
805 		setproctitle("[TLS sandbox] (server) ");
806 #endif
807 		/* Close listen socket. */
808 		proto_close(tlsctx->tls_tcp);
809 		tls_call_exec_server(sock, tcp);
810 		/* NOTREACHED */
811 		PJDLOG_ABORT("Unreachable.");
812 	default:
813 		/* Parent. */
814 		newtlsctx = calloc(1, sizeof(*tlsctx));
815 		if (newtlsctx == NULL) {
816 			error = errno;
817 			proto_close(sock);
818 			proto_close(tcp);
819 			(void)kill(pid, SIGKILL);
820 			return (error);
821 		}
822 		proto_local_address(tcp, newtlsctx->tls_laddr,
823 		    sizeof(newtlsctx->tls_laddr));
824 		PJDLOG_ASSERT(strncmp(newtlsctx->tls_laddr, "tcp://", 6) == 0);
825 		bcopy("tls://", newtlsctx->tls_laddr, 6);
826 		*strrchr(newtlsctx->tls_laddr, ':') = '\0';
827 		proto_remote_address(tcp, newtlsctx->tls_raddr,
828 		    sizeof(newtlsctx->tls_raddr));
829 		PJDLOG_ASSERT(strncmp(newtlsctx->tls_raddr, "tcp://", 6) == 0);
830 		bcopy("tls://", newtlsctx->tls_raddr, 6);
831 		*strrchr(newtlsctx->tls_raddr, ':') = '\0';
832 		proto_close(tcp);
833 		proto_recv(sock, NULL, 0);
834 		newtlsctx->tls_sock = sock;
835 		newtlsctx->tls_tcp = NULL;
836 		newtlsctx->tls_wait_called = true;
837 		newtlsctx->tls_side = TLS_SIDE_SERVER_WORK;
838 		newtlsctx->tls_magic = TLS_CTX_MAGIC;
839 		*newctxp = newtlsctx;
840 		return (0);
841 	}
842 }
843 
844 static int
845 tls_wrap(int fd, bool client, void **ctxp)
846 {
847 	struct tls_ctx *tlsctx;
848 	struct proto_conn *sock;
849 	int error;
850 
851 	tlsctx = calloc(1, sizeof(*tlsctx));
852 	if (tlsctx == NULL)
853 		return (errno);
854 
855 	if (proto_wrap("socketpair", client, fd, &sock) == -1) {
856 		error = errno;
857 		free(tlsctx);
858 		return (error);
859 	}
860 
861 	tlsctx->tls_sock = sock;
862 	tlsctx->tls_tcp = NULL;
863 	tlsctx->tls_wait_called = (client ? false : true);
864 	tlsctx->tls_side = (client ? TLS_SIDE_CLIENT : TLS_SIDE_SERVER_WORK);
865 	tlsctx->tls_magic = TLS_CTX_MAGIC;
866 	*ctxp = tlsctx;
867 
868 	return (0);
869 }
870 
871 static int
872 tls_send(void *ctx, const unsigned char *data, size_t size, int fd)
873 {
874 	struct tls_ctx *tlsctx = ctx;
875 
876 	PJDLOG_ASSERT(tlsctx != NULL);
877 	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
878 	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT ||
879 	    tlsctx->tls_side == TLS_SIDE_SERVER_WORK);
880 	PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
881 	PJDLOG_ASSERT(tlsctx->tls_wait_called);
882 	PJDLOG_ASSERT(fd == -1);
883 
884 	if (proto_send(tlsctx->tls_sock, data, size) == -1)
885 		return (errno);
886 
887 	return (0);
888 }
889 
890 static int
891 tls_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
892 {
893 	struct tls_ctx *tlsctx = ctx;
894 
895 	PJDLOG_ASSERT(tlsctx != NULL);
896 	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
897 	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT ||
898 	    tlsctx->tls_side == TLS_SIDE_SERVER_WORK);
899 	PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
900 	PJDLOG_ASSERT(tlsctx->tls_wait_called);
901 	PJDLOG_ASSERT(fdp == NULL);
902 
903 	if (proto_recv(tlsctx->tls_sock, data, size) == -1)
904 		return (errno);
905 
906 	return (0);
907 }
908 
909 static int
910 tls_descriptor(const void *ctx)
911 {
912 	const struct tls_ctx *tlsctx = ctx;
913 
914 	PJDLOG_ASSERT(tlsctx != NULL);
915 	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
916 
917 	switch (tlsctx->tls_side) {
918 	case TLS_SIDE_CLIENT:
919 	case TLS_SIDE_SERVER_WORK:
920 		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
921 
922 		return (proto_descriptor(tlsctx->tls_sock));
923 	case TLS_SIDE_SERVER_LISTEN:
924 		PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
925 
926 		return (proto_descriptor(tlsctx->tls_tcp));
927 	default:
928 		PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
929 	}
930 }
931 
932 static bool
933 tcp_address_match(const void *ctx, const char *addr)
934 {
935 	const struct tls_ctx *tlsctx = ctx;
936 
937 	PJDLOG_ASSERT(tlsctx != NULL);
938 	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
939 
940 	return (strcmp(tlsctx->tls_raddr, addr) == 0);
941 }
942 
943 static void
944 tls_local_address(const void *ctx, char *addr, size_t size)
945 {
946 	const struct tls_ctx *tlsctx = ctx;
947 
948 	PJDLOG_ASSERT(tlsctx != NULL);
949 	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
950 	PJDLOG_ASSERT(tlsctx->tls_wait_called);
951 
952 	switch (tlsctx->tls_side) {
953 	case TLS_SIDE_CLIENT:
954 		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
955 
956 		PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size);
957 		break;
958 	case TLS_SIDE_SERVER_WORK:
959 		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
960 
961 		PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_laddr, size) < size);
962 		break;
963 	case TLS_SIDE_SERVER_LISTEN:
964 		PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
965 
966 		proto_local_address(tlsctx->tls_tcp, addr, size);
967 		PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0);
968 		/* Replace tcp:// prefix with tls:// */
969 		bcopy("tls://", addr, 6);
970 		break;
971 	default:
972 		PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
973 	}
974 }
975 
976 static void
977 tls_remote_address(const void *ctx, char *addr, size_t size)
978 {
979 	const struct tls_ctx *tlsctx = ctx;
980 
981 	PJDLOG_ASSERT(tlsctx != NULL);
982 	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
983 	PJDLOG_ASSERT(tlsctx->tls_wait_called);
984 
985 	switch (tlsctx->tls_side) {
986 	case TLS_SIDE_CLIENT:
987 		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
988 
989 		PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size);
990 		break;
991 	case TLS_SIDE_SERVER_WORK:
992 		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
993 
994 		PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_raddr, size) < size);
995 		break;
996 	case TLS_SIDE_SERVER_LISTEN:
997 		PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
998 
999 		proto_remote_address(tlsctx->tls_tcp, addr, size);
1000 		PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0);
1001 		/* Replace tcp:// prefix with tls:// */
1002 		bcopy("tls://", addr, 6);
1003 		break;
1004 	default:
1005 		PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
1006 	}
1007 }
1008 
1009 static void
1010 tls_close(void *ctx)
1011 {
1012 	struct tls_ctx *tlsctx = ctx;
1013 
1014 	PJDLOG_ASSERT(tlsctx != NULL);
1015 	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
1016 
1017 	if (tlsctx->tls_sock != NULL) {
1018 		proto_close(tlsctx->tls_sock);
1019 		tlsctx->tls_sock = NULL;
1020 	}
1021 	if (tlsctx->tls_tcp != NULL) {
1022 		proto_close(tlsctx->tls_tcp);
1023 		tlsctx->tls_tcp = NULL;
1024 	}
1025 	tlsctx->tls_side = 0;
1026 	tlsctx->tls_magic = 0;
1027 	free(tlsctx);
1028 }
1029 
1030 static int
1031 tls_exec(int argc, char *argv[])
1032 {
1033 
1034 	PJDLOG_ASSERT(argc > 3);
1035 	PJDLOG_ASSERT(strcmp(argv[0], "tls") == 0);
1036 
1037 	pjdlog_init(atoi(argv[3]) == 0 ? PJDLOG_MODE_SYSLOG : PJDLOG_MODE_STD);
1038 
1039 	if (strcmp(argv[2], "client") == 0) {
1040 		if (argc != 10)
1041 			return (EINVAL);
1042 		tls_exec_client(argv[1], atoi(argv[3]),
1043 		    argv[4][0] == '\0' ? NULL : argv[4], argv[5], argv[6],
1044 		    argv[7], atoi(argv[8]), atoi(argv[9]));
1045 	} else if (strcmp(argv[2], "server") == 0) {
1046 		if (argc != 7)
1047 			return (EINVAL);
1048 		tls_exec_server(argv[1], atoi(argv[3]), argv[4], argv[5],
1049 		    atoi(argv[6]));
1050 	}
1051 	return (EINVAL);
1052 }
1053 
/*
 * Method table of the "tls" protocol; handed to proto_register() by
 * tls_ctor().
 */
static struct proto tls_proto = {
	.prt_name = "tls",
	.prt_connect = tls_connect,
	.prt_connect_wait = tls_connect_wait,
	.prt_server = tls_server,
	.prt_accept = tls_accept,
	.prt_wrap = tls_wrap,
	.prt_send = tls_send,
	.prt_recv = tls_recv,
	.prt_descriptor = tls_descriptor,
	/* Note the tcp_ prefix: this is the matcher defined in this file. */
	.prt_address_match = tcp_address_match,
	.prt_local_address = tls_local_address,
	.prt_remote_address = tls_remote_address,
	.prt_close = tls_close,
	.prt_exec = tls_exec
};
1070 
/*
 * Constructor that registers the "tls" protocol with proto_register();
 * the second argument passed is false.
 */
static __constructor void
tls_ctor(void)
{

	proto_register(&tls_proto, false);
}
1077