xref: /freebsd/tests/sys/kern/ktls_test.c (revision 3332f1b444d4a73238e9f59cca27bfc95fe936bd)
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause
3  *
4  * Copyright (c) 2021 Netflix Inc.
5  * Written by: John Baldwin <jhb@FreeBSD.org>
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  * 1. Redistributions of source code must retain the above copyright
11  *    notice, this list of conditions and the following disclaimer.
12  * 2. Redistributions in binary form must reproduce the above copyright
13  *    notice, this list of conditions and the following disclaimer in the
14  *    documentation and/or other materials provided with the distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
20  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
22  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
23  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
24  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
25  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
26  * SUCH DAMAGE.
27  */
28 
#include <sys/types.h>
#include <sys/endian.h>
#include <sys/event.h>
#include <sys/ktls.h>
#include <sys/socket.h>
#include <sys/sysctl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <crypto/cryptodev.h>
#include <assert.h>
#include <err.h>
#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <atf-c.h>

#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/hmac.h>
49 
/*
 * Skip the current test case unless the kernel supports TLS offload and
 * it is enabled via the kern.ipc.tls.enable sysctl.  Any other failure
 * to read the sysctl fails the test.
 */
static void
require_ktls(void)
{
	size_t len;
	bool enable;

	len = sizeof(enable);
	if (sysctlbyname("kern.ipc.tls.enable", &enable, &len, NULL, 0) == -1) {
		/* ENOENT: the sysctl node does not exist at all. */
		if (errno == ENOENT)
			atf_tc_skip("kernel does not support TLS offload");
		/*
		 * atf_libc_error() only constructs an error object and
		 * does not stop the test; fail explicitly so that the
		 * uninitialized 'enable' is never consulted below.
		 */
		atf_tc_fail("Failed to read kern.ipc.tls.enable: %s",
		    strerror(errno));
	}

	if (!enable)
		atf_tc_skip("Kernel TLS is disabled");
}

#define	ATF_REQUIRE_KTLS()	require_ktls()
68 
/*
 * Return a pseudo-random printable ASCII character, i.e. a value in the
 * range 0x20 (' ') through 0x7e ('~') inclusive.
 */
static char
rdigit(void)
{
	const long lowest = 0x20;
	const long span = 0x7f - 0x20;

	return (lowest + random() % span);
}
75 
/*
 * Allocate a 'len'-byte buffer filled with random printable characters
 * from rdigit().  Returns NULL when 'len' is zero; otherwise the caller
 * must free() the returned buffer.
 */
static char *
alloc_buffer(size_t len)
{
	char *buf, *p;

	if (len == 0)
		return (NULL);

	buf = malloc(len);
	for (p = buf; p != buf + len; p++)
		*p = rdigit();
	return (buf);
}
89 
90 static bool
91 socketpair_tcp(int *sv)
92 {
93 	struct pollfd pfd;
94 	struct sockaddr_in sin;
95 	socklen_t len;
96 	int as, cs, ls;
97 
98 	ls = socket(PF_INET, SOCK_STREAM, 0);
99 	if (ls == -1) {
100 		warn("socket() for listen");
101 		return (false);
102 	}
103 
104 	memset(&sin, 0, sizeof(sin));
105 	sin.sin_len = sizeof(sin);
106 	sin.sin_family = AF_INET;
107 	sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
108 	if (bind(ls, (struct sockaddr *)&sin, sizeof(sin)) == -1) {
109 		warn("bind");
110 		close(ls);
111 		return (false);
112 	}
113 
114 	if (listen(ls, 1) == -1) {
115 		warn("listen");
116 		close(ls);
117 		return (false);
118 	}
119 
120 	len = sizeof(sin);
121 	if (getsockname(ls, (struct sockaddr *)&sin, &len) == -1) {
122 		warn("getsockname");
123 		close(ls);
124 		return (false);
125 	}
126 
127 	cs = socket(PF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0);
128 	if (cs == -1) {
129 		warn("socket() for connect");
130 		close(ls);
131 		return (false);
132 	}
133 
134 	if (connect(cs, (struct sockaddr *)&sin, sizeof(sin)) == -1) {
135 		if (errno != EINPROGRESS) {
136 			warn("connect");
137 			close(ls);
138 			close(cs);
139 			return (false);
140 		}
141 	}
142 
143 	as = accept4(ls, NULL, NULL, SOCK_NONBLOCK);
144 	if (as == -1) {
145 		warn("accept4");
146 		close(ls);
147 		close(cs);
148 		return (false);
149 	}
150 
151 	close(ls);
152 
153 	pfd.fd = cs;
154 	pfd.events = POLLOUT;
155 	pfd.revents = 0;
156 	ATF_REQUIRE(poll(&pfd, 1, INFTIM) == 1);
157 	ATF_REQUIRE(pfd.revents == POLLOUT);
158 
159 	sv[0] = cs;
160 	sv[1] = as;
161 	return (true);
162 }
163 
/*
 * Put 'fd' into blocking mode by clearing O_NONBLOCK from its file
 * status flags.  Fails the test if either fcntl(2) call fails.
 */
static void
fd_set_blocking(int fd)
{
	int flags;

	/* Read the current status flags, then write them back sans O_NONBLOCK. */
	ATF_REQUIRE((flags = fcntl(fd, F_GETFL)) != -1);
	flags = flags & ~O_NONBLOCK;
	ATF_REQUIRE(fcntl(fd, F_SETFL, flags) != -1);
}
173 
174 static bool
175 cbc_decrypt(const EVP_CIPHER *cipher, const char *key, const char *iv,
176     const char *input, char *output, size_t size)
177 {
178 	EVP_CIPHER_CTX *ctx;
179 	int outl, total;
180 
181 	ctx = EVP_CIPHER_CTX_new();
182 	if (ctx == NULL) {
183 		warnx("EVP_CIPHER_CTX_new failed: %s",
184 		    ERR_error_string(ERR_get_error(), NULL));
185 		return (false);
186 	}
187 	if (EVP_CipherInit_ex(ctx, cipher, NULL, (const u_char *)key,
188 	    (const u_char *)iv, 0) != 1) {
189 		warnx("EVP_CipherInit_ex failed: %s",
190 		    ERR_error_string(ERR_get_error(), NULL));
191 		EVP_CIPHER_CTX_free(ctx);
192 		return (false);
193 	}
194 	EVP_CIPHER_CTX_set_padding(ctx, 0);
195 	if (EVP_CipherUpdate(ctx, (u_char *)output, &outl,
196 	    (const u_char *)input, size) != 1) {
197 		warnx("EVP_CipherUpdate failed: %s",
198 		    ERR_error_string(ERR_get_error(), NULL));
199 		EVP_CIPHER_CTX_free(ctx);
200 		return (false);
201 	}
202 	total = outl;
203 	if (EVP_CipherFinal_ex(ctx, (u_char *)output + outl, &outl) != 1) {
204 		warnx("EVP_CipherFinal_ex failed: %s",
205 		    ERR_error_string(ERR_get_error(), NULL));
206 		EVP_CIPHER_CTX_free(ctx);
207 		return (false);
208 	}
209 	total += outl;
210 	if ((size_t)total != size) {
211 		warnx("decrypt size mismatch: %zu vs %d", size, total);
212 		EVP_CIPHER_CTX_free(ctx);
213 		return (false);
214 	}
215 	EVP_CIPHER_CTX_free(ctx);
216 	return (true);
217 }
218 
219 static bool
220 verify_hash(const EVP_MD *md, const void *key, size_t key_len, const void *aad,
221     size_t aad_len, const void *buffer, size_t len, const void *digest)
222 {
223 	HMAC_CTX *ctx;
224 	unsigned char digest2[EVP_MAX_MD_SIZE];
225 	u_int digest_len;
226 
227 	ctx = HMAC_CTX_new();
228 	if (ctx == NULL) {
229 		warnx("HMAC_CTX_new failed: %s",
230 		    ERR_error_string(ERR_get_error(), NULL));
231 		return (false);
232 	}
233 	if (HMAC_Init_ex(ctx, key, key_len, md, NULL) != 1) {
234 		warnx("HMAC_Init_ex failed: %s",
235 		    ERR_error_string(ERR_get_error(), NULL));
236 		HMAC_CTX_free(ctx);
237 		return (false);
238 	}
239 	if (HMAC_Update(ctx, aad, aad_len) != 1) {
240 		warnx("HMAC_Update (aad) failed: %s",
241 		    ERR_error_string(ERR_get_error(), NULL));
242 		HMAC_CTX_free(ctx);
243 		return (false);
244 	}
245 	if (HMAC_Update(ctx, buffer, len) != 1) {
246 		warnx("HMAC_Update (payload) failed: %s",
247 		    ERR_error_string(ERR_get_error(), NULL));
248 		HMAC_CTX_free(ctx);
249 		return (false);
250 	}
251 	if (HMAC_Final(ctx, digest2, &digest_len) != 1) {
252 		warnx("HMAC_Final failed: %s",
253 		    ERR_error_string(ERR_get_error(), NULL));
254 		HMAC_CTX_free(ctx);
255 		return (false);
256 	}
257 	HMAC_CTX_free(ctx);
258 	if (memcmp(digest, digest2, digest_len) != 0) {
259 		warnx("HMAC mismatch");
260 		return (false);
261 	}
262 	return (true);
263 }
264 
265 static bool
266 aead_decrypt(const EVP_CIPHER *cipher, const char *key, const char *nonce,
267     const void *aad, size_t aad_len, const char *input, char *output,
268     size_t size, const char *tag, size_t tag_len)
269 {
270 	EVP_CIPHER_CTX *ctx;
271 	int outl, total;
272 	bool valid;
273 
274 	ctx = EVP_CIPHER_CTX_new();
275 	if (ctx == NULL) {
276 		warnx("EVP_CIPHER_CTX_new failed: %s",
277 		    ERR_error_string(ERR_get_error(), NULL));
278 		return (false);
279 	}
280 	if (EVP_DecryptInit_ex(ctx, cipher, NULL, (const u_char *)key,
281 	    (const u_char *)nonce) != 1) {
282 		warnx("EVP_DecryptInit_ex failed: %s",
283 		    ERR_error_string(ERR_get_error(), NULL));
284 		EVP_CIPHER_CTX_free(ctx);
285 		return (false);
286 	}
287 	EVP_CIPHER_CTX_set_padding(ctx, 0);
288 	if (aad != NULL) {
289 		if (EVP_DecryptUpdate(ctx, NULL, &outl, (const u_char *)aad,
290 		    aad_len) != 1) {
291 			warnx("EVP_DecryptUpdate for AAD failed: %s",
292 			    ERR_error_string(ERR_get_error(), NULL));
293 			EVP_CIPHER_CTX_free(ctx);
294 			return (false);
295 		}
296 	}
297 	if (EVP_DecryptUpdate(ctx, (u_char *)output, &outl,
298 	    (const u_char *)input, size) != 1) {
299 		warnx("EVP_DecryptUpdate failed: %s",
300 		    ERR_error_string(ERR_get_error(), NULL));
301 		EVP_CIPHER_CTX_free(ctx);
302 		return (false);
303 	}
304 	total = outl;
305 	if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_TAG, tag_len,
306 	    __DECONST(char *, tag)) != 1) {
307 		warnx("EVP_CIPHER_CTX_ctrl(EVP_CTRL_AEAD_SET_TAG) failed: %s",
308 		    ERR_error_string(ERR_get_error(), NULL));
309 		EVP_CIPHER_CTX_free(ctx);
310 		return (false);
311 	}
312 	valid = (EVP_DecryptFinal_ex(ctx, (u_char *)output + outl, &outl) == 1);
313 	total += outl;
314 	if ((size_t)total != size) {
315 		warnx("decrypt size mismatch: %zu vs %d", size, total);
316 		EVP_CIPHER_CTX_free(ctx);
317 		return (false);
318 	}
319 	if (!valid)
320 		warnx("tag mismatch");
321 	EVP_CIPHER_CTX_free(ctx);
322 	return (valid);
323 }
324 
325 static void
326 build_tls_enable(int cipher_alg, size_t cipher_key_len, int auth_alg,
327     int minor, uint64_t seqno, struct tls_enable *en)
328 {
329 	u_int auth_key_len, iv_len;
330 
331 	memset(en, 0, sizeof(*en));
332 
333 	switch (cipher_alg) {
334 	case CRYPTO_AES_CBC:
335 		if (minor == TLS_MINOR_VER_ZERO)
336 			iv_len = AES_BLOCK_LEN;
337 		else
338 			iv_len = 0;
339 		break;
340 	case CRYPTO_AES_NIST_GCM_16:
341 		if (minor == TLS_MINOR_VER_TWO)
342 			iv_len = TLS_AEAD_GCM_LEN;
343 		else
344 			iv_len = TLS_1_3_GCM_IV_LEN;
345 		break;
346 	case CRYPTO_CHACHA20_POLY1305:
347 		iv_len = TLS_CHACHA20_IV_LEN;
348 		break;
349 	default:
350 		iv_len = 0;
351 		break;
352 	}
353 	switch (auth_alg) {
354 	case CRYPTO_SHA1_HMAC:
355 		auth_key_len = SHA1_HASH_LEN;
356 		break;
357 	case CRYPTO_SHA2_256_HMAC:
358 		auth_key_len = SHA2_256_HASH_LEN;
359 		break;
360 	case CRYPTO_SHA2_384_HMAC:
361 		auth_key_len = SHA2_384_HASH_LEN;
362 		break;
363 	default:
364 		auth_key_len = 0;
365 		break;
366 	}
367 	en->cipher_key = alloc_buffer(cipher_key_len);
368 	en->iv = alloc_buffer(iv_len);
369 	en->auth_key = alloc_buffer(auth_key_len);
370 	en->cipher_algorithm = cipher_alg;
371 	en->cipher_key_len = cipher_key_len;
372 	en->iv_len = iv_len;
373 	en->auth_algorithm = auth_alg;
374 	en->auth_key_len = auth_key_len;
375 	en->tls_vmajor = TLS_MAJOR_VER_ONE;
376 	en->tls_vminor = minor;
377 	be64enc(en->rec_seq, seqno);
378 }
379 
380 static void
381 free_tls_enable(struct tls_enable *en)
382 {
383 	free(__DECONST(void *, en->cipher_key));
384 	free(__DECONST(void *, en->iv));
385 	free(__DECONST(void *, en->auth_key));
386 }
387 
388 static const EVP_CIPHER *
389 tls_EVP_CIPHER(const struct tls_enable *en)
390 {
391 	switch (en->cipher_algorithm) {
392 	case CRYPTO_AES_CBC:
393 		switch (en->cipher_key_len) {
394 		case 128 / 8:
395 			return (EVP_aes_128_cbc());
396 		case 256 / 8:
397 			return (EVP_aes_256_cbc());
398 		default:
399 			return (NULL);
400 		}
401 		break;
402 	case CRYPTO_AES_NIST_GCM_16:
403 		switch (en->cipher_key_len) {
404 		case 128 / 8:
405 			return (EVP_aes_128_gcm());
406 		case 256 / 8:
407 			return (EVP_aes_256_gcm());
408 		default:
409 			return (NULL);
410 		}
411 		break;
412 	case CRYPTO_CHACHA20_POLY1305:
413 		return (EVP_chacha20_poly1305());
414 	default:
415 		return (NULL);
416 	}
417 }
418 
419 static const EVP_MD *
420 tls_EVP_MD(const struct tls_enable *en)
421 {
422 	switch (en->auth_algorithm) {
423 	case CRYPTO_SHA1_HMAC:
424 		return (EVP_sha1());
425 	case CRYPTO_SHA2_256_HMAC:
426 		return (EVP_sha256());
427 	case CRYPTO_SHA2_384_HMAC:
428 		return (EVP_sha384());
429 	default:
430 		return (NULL);
431 	}
432 }
433 
434 static size_t
435 tls_header_len(struct tls_enable *en)
436 {
437 	size_t len;
438 
439 	len = sizeof(struct tls_record_layer);
440 	switch (en->cipher_algorithm) {
441 	case CRYPTO_AES_CBC:
442 		if (en->tls_vminor != TLS_MINOR_VER_ZERO)
443 			len += AES_BLOCK_LEN;
444 		return (len);
445 	case CRYPTO_AES_NIST_GCM_16:
446 		if (en->tls_vminor == TLS_MINOR_VER_TWO)
447 			len += sizeof(uint64_t);
448 		return (len);
449 	case CRYPTO_CHACHA20_POLY1305:
450 		return (len);
451 	default:
452 		return (0);
453 	}
454 }
455 
456 static size_t
457 tls_mac_len(struct tls_enable *en)
458 {
459 	switch (en->cipher_algorithm) {
460 	case CRYPTO_AES_CBC:
461 		switch (en->auth_algorithm) {
462 		case CRYPTO_SHA1_HMAC:
463 			return (SHA1_HASH_LEN);
464 		case CRYPTO_SHA2_256_HMAC:
465 			return (SHA2_256_HASH_LEN);
466 		case CRYPTO_SHA2_384_HMAC:
467 			return (SHA2_384_HASH_LEN);
468 		default:
469 			return (0);
470 		}
471 	case CRYPTO_AES_NIST_GCM_16:
472 		return (AES_GMAC_HASH_LEN);
473 	case CRYPTO_CHACHA20_POLY1305:
474 		return (POLY1305_HASH_LEN);
475 	default:
476 		return (0);
477 	}
478 }
479 
480 /* Includes maximum padding for MTE. */
481 static size_t
482 tls_trailer_len(struct tls_enable *en)
483 {
484 	size_t len;
485 
486 	len = tls_mac_len(en);
487 	if (en->cipher_algorithm == CRYPTO_AES_CBC)
488 		len += AES_BLOCK_LEN;
489 	if (en->tls_vminor == TLS_MINOR_VER_THREE)
490 		len++;
491 	return (len);
492 }
493 
494 /* 'len' is the length of the payload application data. */
495 static void
496 tls_mte_aad(struct tls_enable *en, size_t len,
497     const struct tls_record_layer *hdr, uint64_t seqno, struct tls_mac_data *ad)
498 {
499 	ad->seq = htobe64(seqno);
500 	ad->type = hdr->tls_type;
501 	ad->tls_vmajor = hdr->tls_vmajor;
502 	ad->tls_vminor = hdr->tls_vminor;
503 	ad->tls_length = htons(len);
504 }
505 
506 static void
507 tls_12_aead_aad(struct tls_enable *en, size_t len,
508     const struct tls_record_layer *hdr, uint64_t seqno,
509     struct tls_aead_data *ad)
510 {
511 	ad->seq = htobe64(seqno);
512 	ad->type = hdr->tls_type;
513 	ad->tls_vmajor = hdr->tls_vmajor;
514 	ad->tls_vminor = hdr->tls_vminor;
515 	ad->tls_length = htons(len);
516 }
517 
518 static void
519 tls_13_aad(struct tls_enable *en, const struct tls_record_layer *hdr,
520     uint64_t seqno, struct tls_aead_data_13 *ad)
521 {
522 	ad->type = hdr->tls_type;
523 	ad->tls_vmajor = hdr->tls_vmajor;
524 	ad->tls_vminor = hdr->tls_vminor;
525 	ad->tls_length = hdr->tls_length;
526 }
527 
528 static void
529 tls_12_gcm_nonce(struct tls_enable *en, const struct tls_record_layer *hdr,
530     char *nonce)
531 {
532 	memcpy(nonce, en->iv, TLS_AEAD_GCM_LEN);
533 	memcpy(nonce + TLS_AEAD_GCM_LEN, hdr + 1, sizeof(uint64_t));
534 }
535 
536 static void
537 tls_13_nonce(struct tls_enable *en, uint64_t seqno, char *nonce)
538 {
539 	static_assert(TLS_1_3_GCM_IV_LEN == TLS_CHACHA20_IV_LEN,
540 	    "TLS 1.3 nonce length mismatch");
541 	memcpy(nonce, en->iv, TLS_1_3_GCM_IV_LEN);
542 	*(uint64_t *)(nonce + 4) ^= htobe64(seqno);
543 }
544 
545 /*
546  * Decrypt a TLS record 'len' bytes long at 'src' and store the result at
547  * 'dst'.  If the TLS record header length doesn't match or 'dst' doesn't
548  * have sufficient room ('avail'), fail the test.
549  */
550 static size_t
551 decrypt_tls_aes_cbc_mte(struct tls_enable *en, uint64_t seqno, const void *src,
552     size_t len, void *dst, size_t avail, uint8_t *record_type)
553 {
554 	const struct tls_record_layer *hdr;
555 	struct tls_mac_data aad;
556 	const char *iv;
557 	char *buf;
558 	size_t hdr_len, mac_len, payload_len;
559 	int padding;
560 
561 	hdr = src;
562 	hdr_len = tls_header_len(en);
563 	mac_len = tls_mac_len(en);
564 	ATF_REQUIRE(hdr->tls_vmajor == TLS_MAJOR_VER_ONE);
565 	ATF_REQUIRE(hdr->tls_vminor == en->tls_vminor);
566 
567 	/* First, decrypt the outer payload into a temporary buffer. */
568 	payload_len = len - hdr_len;
569 	buf = malloc(payload_len);
570 	if (en->tls_vminor == TLS_MINOR_VER_ZERO)
571 		iv = en->iv;
572 	else
573 		iv = (void *)(hdr + 1);
574 	ATF_REQUIRE(cbc_decrypt(tls_EVP_CIPHER(en), en->cipher_key, iv,
575 	    (const u_char *)src + hdr_len, buf, payload_len));
576 
577 	/*
578 	 * Copy the last encrypted block to use as the IV for the next
579 	 * record for TLS 1.0.
580 	 */
581 	if (en->tls_vminor == TLS_MINOR_VER_ZERO)
582 		memcpy(__DECONST(uint8_t *, en->iv), (const u_char *)src +
583 		    (len - AES_BLOCK_LEN), AES_BLOCK_LEN);
584 
585 	/*
586 	 * Verify trailing padding and strip.
587 	 *
588 	 * The kernel always generates the smallest amount of padding.
589 	 */
590 	padding = buf[payload_len - 1] + 1;
591 	ATF_REQUIRE(padding > 0 && padding <= AES_BLOCK_LEN);
592 	ATF_REQUIRE(payload_len >= mac_len + padding);
593 	payload_len -= padding;
594 
595 	/* Verify HMAC. */
596 	payload_len -= mac_len;
597 	tls_mte_aad(en, payload_len, hdr, seqno, &aad);
598 	ATF_REQUIRE(verify_hash(tls_EVP_MD(en), en->auth_key, en->auth_key_len,
599 	    &aad, sizeof(aad), buf, payload_len, buf + payload_len));
600 
601 	ATF_REQUIRE(payload_len <= avail);
602 	memcpy(dst, buf, payload_len);
603 	*record_type = hdr->tls_type;
604 	return (payload_len);
605 }
606 
607 static size_t
608 decrypt_tls_12_aead(struct tls_enable *en, uint64_t seqno, const void *src,
609     size_t len, void *dst, uint8_t *record_type)
610 {
611 	const struct tls_record_layer *hdr;
612 	struct tls_aead_data aad;
613 	char nonce[12];
614 	size_t hdr_len, mac_len, payload_len;
615 
616 	hdr = src;
617 
618 	hdr_len = tls_header_len(en);
619 	mac_len = tls_mac_len(en);
620 	payload_len = len - (hdr_len + mac_len);
621 	ATF_REQUIRE(hdr->tls_vmajor == TLS_MAJOR_VER_ONE);
622 	ATF_REQUIRE(hdr->tls_vminor == TLS_MINOR_VER_TWO);
623 
624 	tls_12_aead_aad(en, payload_len, hdr, seqno, &aad);
625 	if (en->cipher_algorithm == CRYPTO_AES_NIST_GCM_16)
626 		tls_12_gcm_nonce(en, hdr, nonce);
627 	else
628 		tls_13_nonce(en, seqno, nonce);
629 
630 	ATF_REQUIRE(aead_decrypt(tls_EVP_CIPHER(en), en->cipher_key, nonce,
631 	    &aad, sizeof(aad), (const char *)src + hdr_len, dst, payload_len,
632 	    (const char *)src + hdr_len + payload_len, mac_len));
633 
634 	*record_type = hdr->tls_type;
635 	return (payload_len);
636 }
637 
638 static size_t
639 decrypt_tls_13_aead(struct tls_enable *en, uint64_t seqno, const void *src,
640     size_t len, void *dst, uint8_t *record_type)
641 {
642 	const struct tls_record_layer *hdr;
643 	struct tls_aead_data_13 aad;
644 	char nonce[12];
645 	char *buf;
646 	size_t hdr_len, mac_len, payload_len;
647 
648 	hdr = src;
649 
650 	hdr_len = tls_header_len(en);
651 	mac_len = tls_mac_len(en);
652 	payload_len = len - (hdr_len + mac_len);
653 	ATF_REQUIRE(payload_len >= 1);
654 	ATF_REQUIRE(hdr->tls_type == TLS_RLTYPE_APP);
655 	ATF_REQUIRE(hdr->tls_vmajor == TLS_MAJOR_VER_ONE);
656 	ATF_REQUIRE(hdr->tls_vminor == TLS_MINOR_VER_TWO);
657 
658 	tls_13_aad(en, hdr, seqno, &aad);
659 	tls_13_nonce(en, seqno, nonce);
660 
661 	/*
662 	 * Have to use a temporary buffer for the output due to the
663 	 * record type as the last byte of the trailer.
664 	 */
665 	buf = malloc(payload_len);
666 
667 	ATF_REQUIRE(aead_decrypt(tls_EVP_CIPHER(en), en->cipher_key, nonce,
668 	    &aad, sizeof(aad), (const char *)src + hdr_len, buf, payload_len,
669 	    (const char *)src + hdr_len + payload_len, mac_len));
670 
671 	/* Trim record type. */
672 	*record_type = buf[payload_len - 1];
673 	payload_len--;
674 
675 	memcpy(dst, buf, payload_len);
676 	free(buf);
677 
678 	return (payload_len);
679 }
680 
681 static size_t
682 decrypt_tls_aead(struct tls_enable *en, uint64_t seqno, const void *src,
683     size_t len, void *dst, size_t avail, uint8_t *record_type)
684 {
685 	const struct tls_record_layer *hdr;
686 	size_t payload_len;
687 
688 	hdr = src;
689 	ATF_REQUIRE(ntohs(hdr->tls_length) + sizeof(*hdr) == len);
690 
691 	payload_len = len - (tls_header_len(en) + tls_trailer_len(en));
692 	ATF_REQUIRE(payload_len <= avail);
693 
694 	if (en->tls_vminor == TLS_MINOR_VER_TWO) {
695 		ATF_REQUIRE(decrypt_tls_12_aead(en, seqno, src, len, dst,
696 		    record_type) == payload_len);
697 	} else {
698 		ATF_REQUIRE(decrypt_tls_13_aead(en, seqno, src, len, dst,
699 		    record_type) == payload_len);
700 	}
701 
702 	return (payload_len);
703 }
704 
705 static size_t
706 decrypt_tls_record(struct tls_enable *en, uint64_t seqno, const void *src,
707     size_t len, void *dst, size_t avail, uint8_t *record_type)
708 {
709 	if (en->cipher_algorithm == CRYPTO_AES_CBC)
710 		return (decrypt_tls_aes_cbc_mte(en, seqno, src, len, dst, avail,
711 		    record_type));
712 	else
713 		return (decrypt_tls_aead(en, seqno, src, len, dst, avail,
714 		    record_type));
715 }
716 
717 static void
718 test_ktls_transmit_app_data(struct tls_enable *en, uint64_t seqno, size_t len)
719 {
720 	struct kevent ev;
721 	struct tls_record_layer *hdr;
722 	char *plaintext, *decrypted, *outbuf;
723 	size_t decrypted_len, outbuf_len, outbuf_cap, record_len, written;
724 	ssize_t rv;
725 	int kq, sockets[2];
726 	uint8_t record_type;
727 
728 	plaintext = alloc_buffer(len);
729 	decrypted = malloc(len);
730 	outbuf_cap = tls_header_len(en) + TLS_MAX_MSG_SIZE_V10_2 +
731 	    tls_trailer_len(en);
732 	outbuf = malloc(outbuf_cap);
733 	hdr = (struct tls_record_layer *)outbuf;
734 
735 	ATF_REQUIRE((kq = kqueue()) != -1);
736 
737 	ATF_REQUIRE_MSG(socketpair_tcp(sockets), "failed to create sockets");
738 
739 	ATF_REQUIRE(setsockopt(sockets[1], IPPROTO_TCP, TCP_TXTLS_ENABLE, en,
740 	    sizeof(*en)) == 0);
741 
742 	EV_SET(&ev, sockets[0], EVFILT_READ, EV_ADD, 0, 0, NULL);
743 	ATF_REQUIRE(kevent(kq, &ev, 1, NULL, 0, NULL) == 0);
744 	EV_SET(&ev, sockets[1], EVFILT_WRITE, EV_ADD, 0, 0, NULL);
745 	ATF_REQUIRE(kevent(kq, &ev, 1, NULL, 0, NULL) == 0);
746 
747 	decrypted_len = 0;
748 	outbuf_len = 0;
749 	written = 0;
750 
751 	while (decrypted_len != len) {
752 		ATF_REQUIRE(kevent(kq, NULL, 0, &ev, 1, NULL) == 1);
753 
754 		switch (ev.filter) {
755 		case EVFILT_WRITE:
756 			/* Try to write any remaining data. */
757 			rv = write(ev.ident, plaintext + written,
758 			    len - written);
759 			ATF_REQUIRE_MSG(rv > 0,
760 			    "failed to write to socket");
761 			written += rv;
762 			if (written == len) {
763 				ev.flags = EV_DISABLE;
764 				ATF_REQUIRE(kevent(kq, &ev, 1, NULL, 0,
765 				    NULL) == 0);
766 			}
767 			break;
768 
769 		case EVFILT_READ:
770 			ATF_REQUIRE((ev.flags & EV_EOF) == 0);
771 
772 			/*
773 			 * Try to read data for the next TLS record
774 			 * into outbuf.  Start by reading the header
775 			 * to determine how much additional data to
776 			 * read.
777 			 */
778 			if (outbuf_len < sizeof(struct tls_record_layer)) {
779 				rv = read(ev.ident, outbuf + outbuf_len,
780 				    sizeof(struct tls_record_layer) -
781 				    outbuf_len);
782 				ATF_REQUIRE_MSG(rv > 0,
783 				    "failed to read from socket");
784 				outbuf_len += rv;
785 			}
786 
787 			if (outbuf_len < sizeof(struct tls_record_layer))
788 				break;
789 
790 			record_len = sizeof(struct tls_record_layer) +
791 			    ntohs(hdr->tls_length);
792 			assert(record_len <= outbuf_cap);
793 			assert(record_len > outbuf_len);
794 			rv = read(ev.ident, outbuf + outbuf_len,
795 			    record_len - outbuf_len);
796 			if (rv == -1 && errno == EAGAIN)
797 				break;
798 			ATF_REQUIRE_MSG(rv > 0, "failed to read from socket");
799 
800 			outbuf_len += rv;
801 			if (outbuf_len == record_len) {
802 				decrypted_len += decrypt_tls_record(en, seqno,
803 				    outbuf, outbuf_len,
804 				    decrypted + decrypted_len,
805 				    len - decrypted_len, &record_type);
806 				ATF_REQUIRE(record_type == TLS_RLTYPE_APP);
807 
808 				seqno++;
809 				outbuf_len = 0;
810 			}
811 			break;
812 		}
813 	}
814 
815 	ATF_REQUIRE_MSG(written == decrypted_len,
816 	    "read %zu decrypted bytes, but wrote %zu", decrypted_len, written);
817 
818 	ATF_REQUIRE(memcmp(plaintext, decrypted, len) == 0);
819 
820 	free(outbuf);
821 	free(decrypted);
822 	free(plaintext);
823 
824 	close(sockets[1]);
825 	close(sockets[0]);
826 	close(kq);
827 }
828 
829 static void
830 ktls_send_control_message(int fd, uint8_t type, void *data, size_t len)
831 {
832 	struct msghdr msg;
833 	struct cmsghdr *cmsg;
834 	char cbuf[CMSG_SPACE(sizeof(type))];
835 	struct iovec iov;
836 
837 	memset(&msg, 0, sizeof(msg));
838 
839 	msg.msg_control = cbuf;
840 	msg.msg_controllen = sizeof(cbuf);
841 	cmsg = CMSG_FIRSTHDR(&msg);
842 	cmsg->cmsg_level = IPPROTO_TCP;
843 	cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
844 	cmsg->cmsg_len = CMSG_LEN(sizeof(type));
845 	*(uint8_t *)CMSG_DATA(cmsg) = type;
846 
847 	iov.iov_base = data;
848 	iov.iov_len = len;
849 	msg.msg_iov = &iov;
850 	msg.msg_iovlen = 1;
851 
852 	ATF_REQUIRE(sendmsg(fd, &msg, 0) == (ssize_t)len);
853 }
854 
855 static void
856 test_ktls_transmit_control(struct tls_enable *en, uint64_t seqno, uint8_t type,
857     size_t len)
858 {
859 	struct tls_record_layer *hdr;
860 	char *plaintext, *decrypted, *outbuf;
861 	size_t outbuf_cap, payload_len, record_len;
862 	ssize_t rv;
863 	int sockets[2];
864 	uint8_t record_type;
865 
866 	ATF_REQUIRE(len <= TLS_MAX_MSG_SIZE_V10_2);
867 
868 	plaintext = alloc_buffer(len);
869 	decrypted = malloc(len);
870 	outbuf_cap = tls_header_len(en) + len + tls_trailer_len(en);
871 	outbuf = malloc(outbuf_cap);
872 	hdr = (struct tls_record_layer *)outbuf;
873 
874 	ATF_REQUIRE_MSG(socketpair_tcp(sockets), "failed to create sockets");
875 
876 	ATF_REQUIRE(setsockopt(sockets[1], IPPROTO_TCP, TCP_TXTLS_ENABLE, en,
877 	    sizeof(*en)) == 0);
878 
879 	fd_set_blocking(sockets[0]);
880 	fd_set_blocking(sockets[1]);
881 
882 	ktls_send_control_message(sockets[1], type, plaintext, len);
883 
884 	/*
885 	 * First read the header to determine how much additional data
886 	 * to read.
887 	 */
888 	rv = read(sockets[0], outbuf, sizeof(struct tls_record_layer));
889 	ATF_REQUIRE(rv == sizeof(struct tls_record_layer));
890 	payload_len = ntohs(hdr->tls_length);
891 	record_len = payload_len + sizeof(struct tls_record_layer);
892 	assert(record_len <= outbuf_cap);
893 	rv = read(sockets[0], outbuf + sizeof(struct tls_record_layer),
894 	    payload_len);
895 	ATF_REQUIRE(rv == (ssize_t)payload_len);
896 
897 	rv = decrypt_tls_record(en, seqno, outbuf, record_len, decrypted, len,
898 	    &record_type);
899 
900 	ATF_REQUIRE_MSG((ssize_t)len == rv,
901 	    "read %zd decrypted bytes, but wrote %zu", rv, len);
902 	ATF_REQUIRE(record_type == type);
903 
904 	ATF_REQUIRE(memcmp(plaintext, decrypted, len) == 0);
905 
906 	free(outbuf);
907 	free(decrypted);
908 	free(plaintext);
909 
910 	close(sockets[1]);
911 	close(sockets[0]);
912 }
913 
/*
 * Cipher suite tables.  Each table invokes M(name, cipher_alg, key_size,
 * auth_alg, minor) once per suite.  'name' becomes part of the test case
 * names, 'key_size' is the cipher key length in bytes, 'auth_alg' is the
 * HMAC algorithm (0 for AEAD suites) and 'minor' is the TLS 1.x minor
 * version.  The last entry of each table ends without a line
 * continuation.
 */
#define	AES_CBC_TESTS(M)						\
	M(aes128_cbc_1_0_sha1, CRYPTO_AES_CBC, 128 / 8,			\
	    CRYPTO_SHA1_HMAC, TLS_MINOR_VER_ZERO)			\
	M(aes256_cbc_1_0_sha1, CRYPTO_AES_CBC, 256 / 8,			\
	    CRYPTO_SHA1_HMAC, TLS_MINOR_VER_ZERO)			\
	M(aes128_cbc_1_1_sha1, CRYPTO_AES_CBC, 128 / 8,			\
	    CRYPTO_SHA1_HMAC, TLS_MINOR_VER_ONE)			\
	M(aes256_cbc_1_1_sha1, CRYPTO_AES_CBC, 256 / 8,			\
	    CRYPTO_SHA1_HMAC, TLS_MINOR_VER_ONE)			\
	M(aes128_cbc_1_2_sha1, CRYPTO_AES_CBC, 128 / 8,			\
	    CRYPTO_SHA1_HMAC, TLS_MINOR_VER_TWO)			\
	M(aes256_cbc_1_2_sha1, CRYPTO_AES_CBC, 256 / 8,			\
	    CRYPTO_SHA1_HMAC, TLS_MINOR_VER_TWO)			\
	M(aes128_cbc_1_2_sha256, CRYPTO_AES_CBC, 128 / 8,		\
	    CRYPTO_SHA2_256_HMAC, TLS_MINOR_VER_TWO)			\
	M(aes256_cbc_1_2_sha256, CRYPTO_AES_CBC, 256 / 8,		\
	    CRYPTO_SHA2_256_HMAC, TLS_MINOR_VER_TWO)			\
	M(aes128_cbc_1_2_sha384, CRYPTO_AES_CBC, 128 / 8,		\
	    CRYPTO_SHA2_384_HMAC, TLS_MINOR_VER_TWO)			\
	M(aes256_cbc_1_2_sha384, CRYPTO_AES_CBC, 256 / 8,		\
	    CRYPTO_SHA2_384_HMAC, TLS_MINOR_VER_TWO)

#define	AES_GCM_TESTS(M)						\
	M(aes128_gcm_1_2, CRYPTO_AES_NIST_GCM_16, 128 / 8, 0,		\
	    TLS_MINOR_VER_TWO)						\
	M(aes256_gcm_1_2, CRYPTO_AES_NIST_GCM_16, 256 / 8, 0,		\
	    TLS_MINOR_VER_TWO)						\
	M(aes128_gcm_1_3, CRYPTO_AES_NIST_GCM_16, 128 / 8, 0,		\
	    TLS_MINOR_VER_THREE)					\
	M(aes256_gcm_1_3, CRYPTO_AES_NIST_GCM_16, 256 / 8, 0,		\
	    TLS_MINOR_VER_THREE)

#define	CHACHA20_TESTS(M)						\
	M(chacha20_poly1305_1_2, CRYPTO_CHACHA20_POLY1305, 256 / 8, 0,	\
	    TLS_MINOR_VER_TWO)						\
	M(chacha20_poly1305_1_3, CRYPTO_CHACHA20_POLY1305, 256 / 8, 0,	\
	    TLS_MINOR_VER_THREE)
951 
/*
 * Test case generators.  Each GEN_* macro defines an ATF test case named
 * ktls_transmit_<suite>_<variant> together with its body, and the
 * matching ADD_* macro registers that test case inside ATF_TP_ADD_TCS().
 * A generated body skips unless KTLS is enabled, builds a session whose
 * first sequence number comes from random(), runs one transmit test and
 * then frees the session buffers.
 */
#define	GEN_TRANSMIT_APP_DATA_TEST(suite, alg, keylen, mac, vminor,	\
	    variant, nbytes)						\
ATF_TC_WITHOUT_HEAD(ktls_transmit_##suite##_##variant);		\
ATF_TC_BODY(ktls_transmit_##suite##_##variant, tc)			\
{									\
	struct tls_enable session;					\
	uint64_t first_seqno;						\
									\
	ATF_REQUIRE_KTLS();						\
	first_seqno = random();						\
	build_tls_enable(alg, keylen, mac, vminor, first_seqno,		\
	    &session);							\
	test_ktls_transmit_app_data(&session, first_seqno, nbytes);	\
	free_tls_enable(&session);					\
}

#define	ADD_TRANSMIT_APP_DATA_TEST(suite, alg, keylen, mac, vminor,	\
	    variant)							\
	ATF_TP_ADD_TC(tp, ktls_transmit_##suite##_##variant);

#define	GEN_TRANSMIT_CONTROL_TEST(suite, alg, keylen, mac, vminor,	\
	    rtype, nbytes)						\
ATF_TC_WITHOUT_HEAD(ktls_transmit_##suite##_control);			\
ATF_TC_BODY(ktls_transmit_##suite##_control, tc)			\
{									\
	struct tls_enable session;					\
	uint64_t first_seqno;						\
									\
	ATF_REQUIRE_KTLS();						\
	first_seqno = random();						\
	build_tls_enable(alg, keylen, mac, vminor, first_seqno,		\
	    &session);							\
	test_ktls_transmit_control(&session, first_seqno, rtype,	\
	    nbytes);							\
	free_tls_enable(&session);					\
}

#define	ADD_TRANSMIT_CONTROL_TEST(suite, alg, keylen, mac, vminor)	\
	ATF_TP_ADD_TC(tp, ktls_transmit_##suite##_control);

/* Three tests per suite: short and long app data, plus an Alert record. */
#define	GEN_TRANSMIT_TESTS(suite, alg, keylen, mac, vminor)		\
	GEN_TRANSMIT_APP_DATA_TEST(suite, alg, keylen, mac, vminor,	\
	    short, 64)							\
	GEN_TRANSMIT_APP_DATA_TEST(suite, alg, keylen, mac, vminor,	\
	    long, 64 * 1024)						\
	GEN_TRANSMIT_CONTROL_TEST(suite, alg, keylen, mac, vminor,	\
	    0x21 /* Alert */, 32)

#define	ADD_TRANSMIT_TESTS(suite, alg, keylen, mac, vminor)		\
	ADD_TRANSMIT_APP_DATA_TEST(suite, alg, keylen, mac, vminor,	\
	    short)							\
	ADD_TRANSMIT_APP_DATA_TEST(suite, alg, keylen, mac, vminor,	\
	    long)							\
	ADD_TRANSMIT_CONTROL_TEST(suite, alg, keylen, mac, vminor)
1009 
1010 /*
1011  * For each supported cipher suite, run three transmit tests:
1012  *
1013  * - a short test which sends 64 bytes of application data (likely as
1014  *   a single TLS record)
1015  *
1016  * - a long test which sends 64KB of application data (split across
1017  *   multiple TLS records)
1018  *
1019  * - a control test which sends a single record with a specific
1020  *   content type via sendmsg()
1021  */
1022 AES_CBC_TESTS(GEN_TRANSMIT_TESTS);
1023 AES_GCM_TESTS(GEN_TRANSMIT_TESTS);
1024 CHACHA20_TESTS(GEN_TRANSMIT_TESTS);
1025 
1026 ATF_TP_ADD_TCS(tp)
1027 {
1028 	AES_CBC_TESTS(ADD_TRANSMIT_TESTS);
1029 	AES_GCM_TESTS(ADD_TRANSMIT_TESTS);
1030 	CHACHA20_TESTS(ADD_TRANSMIT_TESTS);
1031 
1032 	return (atf_no_error());
1033 }
1034