xref: /linux/tools/testing/selftests/net/tls.c (revision d3c9510dc900e9ff3ea330189c0465c9f00fba18)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #define _GNU_SOURCE
4 
5 #include <arpa/inet.h>
6 #include <errno.h>
7 #include <error.h>
8 #include <fcntl.h>
9 #include <poll.h>
10 #include <stdio.h>
11 #include <stdlib.h>
12 #include <unistd.h>
13 
14 #include <linux/tls.h>
15 #include <linux/tcp.h>
16 #include <linux/socket.h>
17 
18 #include <sys/epoll.h>
19 #include <sys/types.h>
20 #include <sys/sendfile.h>
21 #include <sys/socket.h>
22 #include <sys/stat.h>
23 
24 #include "../kselftest_harness.h"
25 
/*
 * Maximum TLS record plaintext size (2^14 bytes); the tests use it as the
 * size of one full record's worth of payload.
 */
#define TLS_PAYLOAD_MAX_LEN 16384
/* Socket option level for kTLS, presumably defined here in case libc lacks it. */
#define SOL_TLS 282

/*
 * Non-zero when variants flagged fips_non_compliant must be skipped.
 * Not assigned in this part of the file — presumably set at startup;
 * NOTE(review): confirm against main().
 */
static int fips_enabled;
30 
/*
 * Storage large enough for the crypto_info of any cipher exercised here.
 * Every cipher-specific member begins with a struct tls_crypto_info header,
 * so @crypto_info gives access to version/cipher_type regardless of cipher.
 * @len: number of bytes of the union that are valid for the chosen cipher;
 *       this is the optlen handed to setsockopt(SOL_TLS, TLS_TX/TLS_RX).
 */
struct tls_crypto_info_keys {
	union {
		struct tls_crypto_info crypto_info;
		struct tls12_crypto_info_aes_gcm_128 aes128;
		struct tls12_crypto_info_chacha20_poly1305 chacha20;
		struct tls12_crypto_info_sm4_gcm sm4gcm;
		struct tls12_crypto_info_sm4_ccm sm4ccm;
		struct tls12_crypto_info_aes_ccm_128 aesccm128;
		struct tls12_crypto_info_aes_gcm_256 aesgcm256;
		struct tls12_crypto_info_aria_gcm_128 ariagcm128;
		struct tls12_crypto_info_aria_gcm_256 ariagcm256;
	};
	size_t len;
};
45 
46 static void tls_crypto_info_init(uint16_t tls_version, uint16_t cipher_type,
47 				 struct tls_crypto_info_keys *tls12,
48 				 char key_generation)
49 {
50 	memset(tls12, key_generation, sizeof(*tls12));
51 	memset(tls12, 0, sizeof(struct tls_crypto_info));
52 
53 	switch (cipher_type) {
54 	case TLS_CIPHER_CHACHA20_POLY1305:
55 		tls12->len = sizeof(struct tls12_crypto_info_chacha20_poly1305);
56 		tls12->chacha20.info.version = tls_version;
57 		tls12->chacha20.info.cipher_type = cipher_type;
58 		break;
59 	case TLS_CIPHER_AES_GCM_128:
60 		tls12->len = sizeof(struct tls12_crypto_info_aes_gcm_128);
61 		tls12->aes128.info.version = tls_version;
62 		tls12->aes128.info.cipher_type = cipher_type;
63 		break;
64 	case TLS_CIPHER_SM4_GCM:
65 		tls12->len = sizeof(struct tls12_crypto_info_sm4_gcm);
66 		tls12->sm4gcm.info.version = tls_version;
67 		tls12->sm4gcm.info.cipher_type = cipher_type;
68 		break;
69 	case TLS_CIPHER_SM4_CCM:
70 		tls12->len = sizeof(struct tls12_crypto_info_sm4_ccm);
71 		tls12->sm4ccm.info.version = tls_version;
72 		tls12->sm4ccm.info.cipher_type = cipher_type;
73 		break;
74 	case TLS_CIPHER_AES_CCM_128:
75 		tls12->len = sizeof(struct tls12_crypto_info_aes_ccm_128);
76 		tls12->aesccm128.info.version = tls_version;
77 		tls12->aesccm128.info.cipher_type = cipher_type;
78 		break;
79 	case TLS_CIPHER_AES_GCM_256:
80 		tls12->len = sizeof(struct tls12_crypto_info_aes_gcm_256);
81 		tls12->aesgcm256.info.version = tls_version;
82 		tls12->aesgcm256.info.cipher_type = cipher_type;
83 		break;
84 	case TLS_CIPHER_ARIA_GCM_128:
85 		tls12->len = sizeof(struct tls12_crypto_info_aria_gcm_128);
86 		tls12->ariagcm128.info.version = tls_version;
87 		tls12->ariagcm128.info.cipher_type = cipher_type;
88 		break;
89 	case TLS_CIPHER_ARIA_GCM_256:
90 		tls12->len = sizeof(struct tls12_crypto_info_aria_gcm_256);
91 		tls12->ariagcm256.info.version = tls_version;
92 		tls12->ariagcm256.info.cipher_type = cipher_type;
93 		break;
94 	default:
95 		break;
96 	}
97 }
98 
/*
 * Fill @n bytes at @s with pseudo-random data from rand().
 *
 * One rand() value is consumed per sizeof(int) bytes; any tail shorter than
 * an int takes one rand() per byte. Values are copied with memcpy(), so @s
 * may be any (possibly unaligned) byte buffer. This avoids the
 * strict-aliasing and alignment problems of storing through an int pointer.
 */
static void memrnd(void *s, size_t n)
{
	unsigned char *p = s;

	while (n >= sizeof(int)) {
		int r = rand();

		memcpy(p, &r, sizeof(r));
		p += sizeof(r);
		n -= sizeof(r);
	}
	while (n--)
		*p++ = rand();
}
110 
111 static void ulp_sock_pair(struct __test_metadata *_metadata,
112 			  int *fd, int *cfd, bool *notls)
113 {
114 	struct sockaddr_in addr;
115 	socklen_t len;
116 	int sfd, ret;
117 
118 	*notls = false;
119 	len = sizeof(addr);
120 
121 	addr.sin_family = AF_INET;
122 	addr.sin_addr.s_addr = htonl(INADDR_ANY);
123 	addr.sin_port = 0;
124 
125 	*fd = socket(AF_INET, SOCK_STREAM, 0);
126 	sfd = socket(AF_INET, SOCK_STREAM, 0);
127 
128 	ret = bind(sfd, &addr, sizeof(addr));
129 	ASSERT_EQ(ret, 0);
130 	ret = listen(sfd, 10);
131 	ASSERT_EQ(ret, 0);
132 
133 	ret = getsockname(sfd, &addr, &len);
134 	ASSERT_EQ(ret, 0);
135 
136 	ret = connect(*fd, &addr, sizeof(addr));
137 	ASSERT_EQ(ret, 0);
138 
139 	*cfd = accept(sfd, &addr, &len);
140 	ASSERT_GE(*cfd, 0);
141 
142 	close(sfd);
143 
144 	ret = setsockopt(*fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
145 	if (ret != 0) {
146 		ASSERT_EQ(errno, ENOENT);
147 		*notls = true;
148 		printf("Failure setting TCP_ULP, testing without tls\n");
149 		return;
150 	}
151 
152 	ret = setsockopt(*cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
153 	ASSERT_EQ(ret, 0);
154 }
155 
156 /* Produce a basic cmsg */
157 static int tls_send_cmsg(int fd, unsigned char record_type,
158 			 void *data, size_t len, int flags)
159 {
160 	char cbuf[CMSG_SPACE(sizeof(char))];
161 	int cmsg_len = sizeof(char);
162 	struct cmsghdr *cmsg;
163 	struct msghdr msg;
164 	struct iovec vec;
165 
166 	vec.iov_base = data;
167 	vec.iov_len = len;
168 	memset(&msg, 0, sizeof(struct msghdr));
169 	msg.msg_iov = &vec;
170 	msg.msg_iovlen = 1;
171 	msg.msg_control = cbuf;
172 	msg.msg_controllen = sizeof(cbuf);
173 	cmsg = CMSG_FIRSTHDR(&msg);
174 	cmsg->cmsg_level = SOL_TLS;
175 	/* test sending non-record types. */
176 	cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
177 	cmsg->cmsg_len = CMSG_LEN(cmsg_len);
178 	*CMSG_DATA(cmsg) = record_type;
179 	msg.msg_controllen = cmsg->cmsg_len;
180 
181 	return sendmsg(fd, &msg, flags);
182 }
183 
184 static int tls_recv_cmsg(struct __test_metadata *_metadata,
185 			 int fd, unsigned char record_type,
186 			 void *data, size_t len, int flags)
187 {
188 	char cbuf[CMSG_SPACE(sizeof(char))];
189 	struct cmsghdr *cmsg;
190 	unsigned char ctype;
191 	struct msghdr msg;
192 	struct iovec vec;
193 	int n;
194 
195 	vec.iov_base = data;
196 	vec.iov_len = len;
197 	memset(&msg, 0, sizeof(struct msghdr));
198 	msg.msg_iov = &vec;
199 	msg.msg_iovlen = 1;
200 	msg.msg_control = cbuf;
201 	msg.msg_controllen = sizeof(cbuf);
202 
203 	n = recvmsg(fd, &msg, flags);
204 
205 	cmsg = CMSG_FIRSTHDR(&msg);
206 	EXPECT_NE(cmsg, NULL);
207 	EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
208 	EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
209 	ctype = *((unsigned char *)CMSG_DATA(cmsg));
210 	EXPECT_EQ(ctype, record_type);
211 
212 	return n;
213 }
214 
/*
 * Fixture for tests that need the tls ULP attached but install keys (if any)
 * themselves.
 * @fd:    connecting socket.
 * @cfd:   accepted peer socket.
 * @notls: set when the tls ULP is unavailable and the sockets are plain TCP.
 */
FIXTURE(tls_basic)
{
	int fd, cfd;
	bool notls;
};

/* Build the connected socket pair; no keys are installed here. */
FIXTURE_SETUP(tls_basic)
{
	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
}

/* Close both ends of the socket pair. */
FIXTURE_TEARDOWN(tls_basic)
{
	close(self->fd);
	close(self->cfd);
}
231 
232 /* Send some data through with ULP but no keys */
233 TEST_F(tls_basic, base_base)
234 {
235 	char const *test_str = "test_read";
236 	int send_len = 10;
237 	char buf[10];
238 
239 	ASSERT_EQ(strlen(test_str) + 1, send_len);
240 
241 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
242 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
243 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
244 };
245 
246 TEST_F(tls_basic, bad_cipher)
247 {
248 	struct tls_crypto_info_keys tls12;
249 
250 	tls12.crypto_info.version = 200;
251 	tls12.crypto_info.cipher_type = TLS_CIPHER_AES_GCM_128;
252 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
253 
254 	tls12.crypto_info.version = TLS_1_2_VERSION;
255 	tls12.crypto_info.cipher_type = 50;
256 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
257 
258 	tls12.crypto_info.version = TLS_1_2_VERSION;
259 	tls12.crypto_info.cipher_type = 59;
260 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
261 
262 	tls12.crypto_info.version = TLS_1_2_VERSION;
263 	tls12.crypto_info.cipher_type = 10;
264 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
265 
266 	tls12.crypto_info.version = TLS_1_2_VERSION;
267 	tls12.crypto_info.cipher_type = 70;
268 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
269 }
270 
271 TEST_F(tls_basic, recseq_wrap)
272 {
273 	struct tls_crypto_info_keys tls12;
274 	char const *test_str = "test_read";
275 	int send_len = 10;
276 
277 	if (self->notls)
278 		SKIP(return, "no TLS support");
279 
280 	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_128, &tls12, 0);
281 	memset(&tls12.aes128.rec_seq, 0xff, sizeof(tls12.aes128.rec_seq));
282 
283 	ASSERT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
284 	ASSERT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
285 
286 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), -1);
287 	EXPECT_EQ(errno, EBADMSG);
288 }
289 
/*
 * Fixture for tests over a keyed kTLS connection. FIXTURE_SETUP(tls)
 * installs TLS_TX on @fd and TLS_RX on @cfd.
 * @notls: the tls ULP is unavailable and both sockets are plain TCP.
 */
FIXTURE(tls)
{
	int fd, cfd;
	bool notls;
};

/*
 * Per-variant parameters.
 * @tls_version, @cipher_type: passed to tls_crypto_info_init().
 * @nopad: also set TLS_RX_EXPECT_NO_PAD on the receiving socket.
 * @fips_non_compliant: skip this variant when fips_enabled is set.
 */
FIXTURE_VARIANT(tls)
{
	uint16_t tls_version;
	uint16_t cipher_type;
	bool nopad, fips_non_compliant;
};
302 
/*
 * One variant per TLS version / cipher combination; the harness runs each
 * TEST_F(tls, ...) once per entry. Entries flagged fips_non_compliant are
 * skipped by FIXTURE_SETUP(tls) when fips_enabled is set.
 */
FIXTURE_VARIANT_ADD(tls, 12_aes_gcm)
{
	.tls_version = TLS_1_2_VERSION,
	.cipher_type = TLS_CIPHER_AES_GCM_128,
};

FIXTURE_VARIANT_ADD(tls, 13_aes_gcm)
{
	.tls_version = TLS_1_3_VERSION,
	.cipher_type = TLS_CIPHER_AES_GCM_128,
};

FIXTURE_VARIANT_ADD(tls, 12_chacha)
{
	.tls_version = TLS_1_2_VERSION,
	.cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
	.fips_non_compliant = true,
};

FIXTURE_VARIANT_ADD(tls, 13_chacha)
{
	.tls_version = TLS_1_3_VERSION,
	.cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
	.fips_non_compliant = true,
};

FIXTURE_VARIANT_ADD(tls, 13_sm4_gcm)
{
	.tls_version = TLS_1_3_VERSION,
	.cipher_type = TLS_CIPHER_SM4_GCM,
	.fips_non_compliant = true,
};

FIXTURE_VARIANT_ADD(tls, 13_sm4_ccm)
{
	.tls_version = TLS_1_3_VERSION,
	.cipher_type = TLS_CIPHER_SM4_CCM,
	.fips_non_compliant = true,
};

FIXTURE_VARIANT_ADD(tls, 12_aes_ccm)
{
	.tls_version = TLS_1_2_VERSION,
	.cipher_type = TLS_CIPHER_AES_CCM_128,
};

FIXTURE_VARIANT_ADD(tls, 13_aes_ccm)
{
	.tls_version = TLS_1_3_VERSION,
	.cipher_type = TLS_CIPHER_AES_CCM_128,
};

FIXTURE_VARIANT_ADD(tls, 12_aes_gcm_256)
{
	.tls_version = TLS_1_2_VERSION,
	.cipher_type = TLS_CIPHER_AES_GCM_256,
};

FIXTURE_VARIANT_ADD(tls, 13_aes_gcm_256)
{
	.tls_version = TLS_1_3_VERSION,
	.cipher_type = TLS_CIPHER_AES_GCM_256,
};

/* TLS 1.3 AES-GCM-128 with TLS_RX_EXPECT_NO_PAD enabled on the receiver. */
FIXTURE_VARIANT_ADD(tls, 13_nopad)
{
	.tls_version = TLS_1_3_VERSION,
	.cipher_type = TLS_CIPHER_AES_GCM_128,
	.nopad = true,
};

FIXTURE_VARIANT_ADD(tls, 12_aria_gcm)
{
	.tls_version = TLS_1_2_VERSION,
	.cipher_type = TLS_CIPHER_ARIA_GCM_128,
};

FIXTURE_VARIANT_ADD(tls, 12_aria_gcm_256)
{
	.tls_version = TLS_1_2_VERSION,
	.cipher_type = TLS_CIPHER_ARIA_GCM_256,
};
385 
386 FIXTURE_SETUP(tls)
387 {
388 	struct tls_crypto_info_keys tls12;
389 	int one = 1;
390 	int ret;
391 
392 	if (fips_enabled && variant->fips_non_compliant)
393 		SKIP(return, "Unsupported cipher in FIPS mode");
394 
395 	tls_crypto_info_init(variant->tls_version, variant->cipher_type,
396 			     &tls12, 0);
397 
398 	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
399 
400 	if (self->notls)
401 		return;
402 
403 	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
404 	ASSERT_EQ(ret, 0);
405 
406 	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len);
407 	ASSERT_EQ(ret, 0);
408 
409 	if (variant->nopad) {
410 		ret = setsockopt(self->cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
411 				 (void *)&one, sizeof(one));
412 		ASSERT_EQ(ret, 0);
413 	}
414 }
415 
416 FIXTURE_TEARDOWN(tls)
417 {
418 	close(self->fd);
419 	close(self->cfd);
420 }
421 
422 TEST_F(tls, sendfile)
423 {
424 	int filefd = open("/proc/self/exe", O_RDONLY);
425 	struct stat st;
426 
427 	EXPECT_GE(filefd, 0);
428 	fstat(filefd, &st);
429 	EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
430 }
431 
432 TEST_F(tls, send_then_sendfile)
433 {
434 	int filefd = open("/proc/self/exe", O_RDONLY);
435 	char const *test_str = "test_send";
436 	int to_send = strlen(test_str) + 1;
437 	char recv_buf[10];
438 	struct stat st;
439 	char *buf;
440 
441 	EXPECT_GE(filefd, 0);
442 	fstat(filefd, &st);
443 	buf = (char *)malloc(st.st_size);
444 
445 	EXPECT_EQ(send(self->fd, test_str, to_send, 0), to_send);
446 	EXPECT_EQ(recv(self->cfd, recv_buf, to_send, MSG_WAITALL), to_send);
447 	EXPECT_EQ(memcmp(test_str, recv_buf, to_send), 0);
448 
449 	EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
450 	EXPECT_EQ(recv(self->cfd, buf, st.st_size, MSG_WAITALL), st.st_size);
451 }
452 
453 static void chunked_sendfile(struct __test_metadata *_metadata,
454 			     struct _test_data_tls *self,
455 			     uint16_t chunk_size,
456 			     uint16_t extra_payload_size)
457 {
458 	char buf[TLS_PAYLOAD_MAX_LEN];
459 	uint16_t test_payload_size;
460 	int size = 0;
461 	int ret;
462 	char filename[] = "/tmp/mytemp.XXXXXX";
463 	int fd = mkstemp(filename);
464 	off_t offset = 0;
465 
466 	unlink(filename);
467 	ASSERT_GE(fd, 0);
468 	EXPECT_GE(chunk_size, 1);
469 	test_payload_size = chunk_size + extra_payload_size;
470 	ASSERT_GE(TLS_PAYLOAD_MAX_LEN, test_payload_size);
471 	memset(buf, 1, test_payload_size);
472 	size = write(fd, buf, test_payload_size);
473 	EXPECT_EQ(size, test_payload_size);
474 	fsync(fd);
475 
476 	while (size > 0) {
477 		ret = sendfile(self->fd, fd, &offset, chunk_size);
478 		EXPECT_GE(ret, 0);
479 		size -= ret;
480 	}
481 
482 	EXPECT_EQ(recv(self->cfd, buf, test_payload_size, MSG_WAITALL),
483 		  test_payload_size);
484 
485 	close(fd);
486 }
487 
488 TEST_F(tls, multi_chunk_sendfile)
489 {
490 	chunked_sendfile(_metadata, self, 4096, 4096);
491 	chunked_sendfile(_metadata, self, 4096, 0);
492 	chunked_sendfile(_metadata, self, 4096, 1);
493 	chunked_sendfile(_metadata, self, 4096, 2048);
494 	chunked_sendfile(_metadata, self, 8192, 2048);
495 	chunked_sendfile(_metadata, self, 4096, 8192);
496 	chunked_sendfile(_metadata, self, 8192, 4096);
497 	chunked_sendfile(_metadata, self, 12288, 1024);
498 	chunked_sendfile(_metadata, self, 12288, 2000);
499 	chunked_sendfile(_metadata, self, 15360, 100);
500 	chunked_sendfile(_metadata, self, 15360, 300);
501 	chunked_sendfile(_metadata, self, 1, 4096);
502 	chunked_sendfile(_metadata, self, 2048, 4096);
503 	chunked_sendfile(_metadata, self, 2048, 8192);
504 	chunked_sendfile(_metadata, self, 4096, 8192);
505 	chunked_sendfile(_metadata, self, 1024, 12288);
506 	chunked_sendfile(_metadata, self, 2000, 12288);
507 	chunked_sendfile(_metadata, self, 100, 15360);
508 	chunked_sendfile(_metadata, self, 300, 15360);
509 }
510 
511 TEST_F(tls, recv_max)
512 {
513 	unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
514 	char recv_mem[TLS_PAYLOAD_MAX_LEN];
515 	char buf[TLS_PAYLOAD_MAX_LEN];
516 
517 	memrnd(buf, sizeof(buf));
518 
519 	EXPECT_GE(send(self->fd, buf, send_len, 0), 0);
520 	EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
521 	EXPECT_EQ(memcmp(buf, recv_mem, send_len), 0);
522 }
523 
524 TEST_F(tls, recv_small)
525 {
526 	char const *test_str = "test_read";
527 	int send_len = 10;
528 	char buf[10];
529 
530 	send_len = strlen(test_str) + 1;
531 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
532 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
533 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
534 }
535 
536 TEST_F(tls, msg_more)
537 {
538 	char const *test_str = "test_read";
539 	int send_len = 10;
540 	char buf[10 * 2];
541 
542 	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
543 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
544 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
545 	EXPECT_EQ(recv(self->cfd, buf, send_len * 2, MSG_WAITALL),
546 		  send_len * 2);
547 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
548 }
549 
550 TEST_F(tls, msg_more_unsent)
551 {
552 	char const *test_str = "test_read";
553 	int send_len = 10;
554 	char buf[10];
555 
556 	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
557 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
558 }
559 
560 TEST_F(tls, msg_eor)
561 {
562 	char const *test_str = "test_read";
563 	int send_len = 10;
564 	char buf[10];
565 
566 	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_EOR), send_len);
567 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
568 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
569 }
570 
571 TEST_F(tls, sendmsg_single)
572 {
573 	struct msghdr msg;
574 
575 	char const *test_str = "test_sendmsg";
576 	size_t send_len = 13;
577 	struct iovec vec;
578 	char buf[13];
579 
580 	vec.iov_base = (char *)test_str;
581 	vec.iov_len = send_len;
582 	memset(&msg, 0, sizeof(struct msghdr));
583 	msg.msg_iov = &vec;
584 	msg.msg_iovlen = 1;
585 	EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
586 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
587 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
588 }
589 
590 #define MAX_FRAGS	64
591 #define SEND_LEN	13
592 TEST_F(tls, sendmsg_fragmented)
593 {
594 	char const *test_str = "test_sendmsg";
595 	char buf[SEND_LEN * MAX_FRAGS];
596 	struct iovec vec[MAX_FRAGS];
597 	struct msghdr msg;
598 	int i, frags;
599 
600 	for (frags = 1; frags <= MAX_FRAGS; frags++) {
601 		for (i = 0; i < frags; i++) {
602 			vec[i].iov_base = (char *)test_str;
603 			vec[i].iov_len = SEND_LEN;
604 		}
605 
606 		memset(&msg, 0, sizeof(struct msghdr));
607 		msg.msg_iov = vec;
608 		msg.msg_iovlen = frags;
609 
610 		EXPECT_EQ(sendmsg(self->fd, &msg, 0), SEND_LEN * frags);
611 		EXPECT_EQ(recv(self->cfd, buf, SEND_LEN * frags, MSG_WAITALL),
612 			  SEND_LEN * frags);
613 
614 		for (i = 0; i < frags; i++)
615 			EXPECT_EQ(memcmp(buf + SEND_LEN * i,
616 					 test_str, SEND_LEN), 0);
617 	}
618 }
619 #undef MAX_FRAGS
620 #undef SEND_LEN
621 
622 TEST_F(tls, sendmsg_large)
623 {
624 	void *mem = malloc(16384);
625 	size_t send_len = 16384;
626 	size_t sends = 128;
627 	struct msghdr msg;
628 	size_t recvs = 0;
629 	size_t sent = 0;
630 
631 	memset(&msg, 0, sizeof(struct msghdr));
632 	while (sent++ < sends) {
633 		struct iovec vec = { (void *)mem, send_len };
634 
635 		msg.msg_iov = &vec;
636 		msg.msg_iovlen = 1;
637 		EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
638 	}
639 
640 	while (recvs++ < sends) {
641 		EXPECT_NE(recv(self->cfd, mem, send_len, 0), -1);
642 	}
643 
644 	free(mem);
645 }
646 
647 TEST_F(tls, sendmsg_multiple)
648 {
649 	char const *test_str = "test_sendmsg_multiple";
650 	struct iovec vec[5];
651 	char *test_strs[5];
652 	struct msghdr msg;
653 	int total_len = 0;
654 	int len_cmp = 0;
655 	int iov_len = 5;
656 	char *buf;
657 	int i;
658 
659 	memset(&msg, 0, sizeof(struct msghdr));
660 	for (i = 0; i < iov_len; i++) {
661 		test_strs[i] = (char *)malloc(strlen(test_str) + 1);
662 		snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
663 		vec[i].iov_base = (void *)test_strs[i];
664 		vec[i].iov_len = strlen(test_strs[i]) + 1;
665 		total_len += vec[i].iov_len;
666 	}
667 	msg.msg_iov = vec;
668 	msg.msg_iovlen = iov_len;
669 
670 	EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
671 	buf = malloc(total_len);
672 	EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
673 	for (i = 0; i < iov_len; i++) {
674 		EXPECT_EQ(memcmp(test_strs[i], buf + len_cmp,
675 				 strlen(test_strs[i])),
676 			  0);
677 		len_cmp += strlen(buf + len_cmp) + 1;
678 	}
679 	for (i = 0; i < iov_len; i++)
680 		free(test_strs[i]);
681 	free(buf);
682 }
683 
684 TEST_F(tls, sendmsg_multiple_stress)
685 {
686 	char const *test_str = "abcdefghijklmno";
687 	struct iovec vec[1024];
688 	char *test_strs[1024];
689 	int iov_len = 1024;
690 	int total_len = 0;
691 	char buf[1 << 14];
692 	struct msghdr msg;
693 	int len_cmp = 0;
694 	int i;
695 
696 	memset(&msg, 0, sizeof(struct msghdr));
697 	for (i = 0; i < iov_len; i++) {
698 		test_strs[i] = (char *)malloc(strlen(test_str) + 1);
699 		snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
700 		vec[i].iov_base = (void *)test_strs[i];
701 		vec[i].iov_len = strlen(test_strs[i]) + 1;
702 		total_len += vec[i].iov_len;
703 	}
704 	msg.msg_iov = vec;
705 	msg.msg_iovlen = iov_len;
706 
707 	EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
708 	EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
709 
710 	for (i = 0; i < iov_len; i++)
711 		len_cmp += strlen(buf + len_cmp) + 1;
712 
713 	for (i = 0; i < iov_len; i++)
714 		free(test_strs[i]);
715 }
716 
717 TEST_F(tls, splice_from_pipe)
718 {
719 	int send_len = TLS_PAYLOAD_MAX_LEN;
720 	char mem_send[TLS_PAYLOAD_MAX_LEN];
721 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
722 	int p[2];
723 
724 	ASSERT_GE(pipe(p), 0);
725 	EXPECT_GE(write(p[1], mem_send, send_len), 0);
726 	EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), 0);
727 	EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
728 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
729 }
730 
731 TEST_F(tls, splice_more)
732 {
733 	unsigned int f = SPLICE_F_NONBLOCK | SPLICE_F_MORE | SPLICE_F_GIFT;
734 	int send_len = TLS_PAYLOAD_MAX_LEN;
735 	char mem_send[TLS_PAYLOAD_MAX_LEN];
736 	int i, send_pipe = 1;
737 	int p[2];
738 
739 	ASSERT_GE(pipe(p), 0);
740 	EXPECT_GE(write(p[1], mem_send, send_len), 0);
741 	for (i = 0; i < 32; i++)
742 		EXPECT_EQ(splice(p[0], NULL, self->fd, NULL, send_pipe, f), 1);
743 }
744 
745 TEST_F(tls, splice_from_pipe2)
746 {
747 	int send_len = 16000;
748 	char mem_send[16000];
749 	char mem_recv[16000];
750 	int p2[2];
751 	int p[2];
752 
753 	memrnd(mem_send, sizeof(mem_send));
754 
755 	ASSERT_GE(pipe(p), 0);
756 	ASSERT_GE(pipe(p2), 0);
757 	EXPECT_EQ(write(p[1], mem_send, 8000), 8000);
758 	EXPECT_EQ(splice(p[0], NULL, self->fd, NULL, 8000, 0), 8000);
759 	EXPECT_EQ(write(p2[1], mem_send + 8000, 8000), 8000);
760 	EXPECT_EQ(splice(p2[0], NULL, self->fd, NULL, 8000, 0), 8000);
761 	EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
762 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
763 }
764 
765 TEST_F(tls, send_and_splice)
766 {
767 	int send_len = TLS_PAYLOAD_MAX_LEN;
768 	char mem_send[TLS_PAYLOAD_MAX_LEN];
769 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
770 	char const *test_str = "test_read";
771 	int send_len2 = 10;
772 	char buf[10];
773 	int p[2];
774 
775 	ASSERT_GE(pipe(p), 0);
776 	EXPECT_EQ(send(self->fd, test_str, send_len2, 0), send_len2);
777 	EXPECT_EQ(recv(self->cfd, buf, send_len2, MSG_WAITALL), send_len2);
778 	EXPECT_EQ(memcmp(test_str, buf, send_len2), 0);
779 
780 	EXPECT_GE(write(p[1], mem_send, send_len), send_len);
781 	EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), send_len);
782 
783 	EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
784 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
785 }
786 
787 TEST_F(tls, splice_to_pipe)
788 {
789 	int send_len = TLS_PAYLOAD_MAX_LEN;
790 	char mem_send[TLS_PAYLOAD_MAX_LEN];
791 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
792 	int p[2];
793 
794 	memrnd(mem_send, sizeof(mem_send));
795 
796 	ASSERT_GE(pipe(p), 0);
797 	EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
798 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), send_len);
799 	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
800 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
801 }
802 
803 TEST_F(tls, splice_cmsg_to_pipe)
804 {
805 	char *test_str = "test_read";
806 	char record_type = 100;
807 	int send_len = 10;
808 	char buf[10];
809 	int p[2];
810 
811 	if (self->notls)
812 		SKIP(return, "no TLS support");
813 
814 	ASSERT_GE(pipe(p), 0);
815 	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
816 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
817 	EXPECT_EQ(errno, EINVAL);
818 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
819 	EXPECT_EQ(errno, EIO);
820 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
821 				buf, sizeof(buf), MSG_WAITALL),
822 		  send_len);
823 	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
824 }
825 
826 TEST_F(tls, splice_dec_cmsg_to_pipe)
827 {
828 	char *test_str = "test_read";
829 	char record_type = 100;
830 	int send_len = 10;
831 	char buf[10];
832 	int p[2];
833 
834 	if (self->notls)
835 		SKIP(return, "no TLS support");
836 
837 	ASSERT_GE(pipe(p), 0);
838 	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
839 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
840 	EXPECT_EQ(errno, EIO);
841 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
842 	EXPECT_EQ(errno, EINVAL);
843 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
844 				buf, sizeof(buf), MSG_WAITALL),
845 		  send_len);
846 	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
847 }
848 
849 TEST_F(tls, recv_and_splice)
850 {
851 	int send_len = TLS_PAYLOAD_MAX_LEN;
852 	char mem_send[TLS_PAYLOAD_MAX_LEN];
853 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
854 	int half = send_len / 2;
855 	int p[2];
856 
857 	ASSERT_GE(pipe(p), 0);
858 	EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
859 	/* Recv hald of the record, splice the other half */
860 	EXPECT_EQ(recv(self->cfd, mem_recv, half, MSG_WAITALL), half);
861 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, half, SPLICE_F_NONBLOCK),
862 		  half);
863 	EXPECT_EQ(read(p[0], &mem_recv[half], half), half);
864 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
865 }
866 
867 TEST_F(tls, peek_and_splice)
868 {
869 	int send_len = TLS_PAYLOAD_MAX_LEN;
870 	char mem_send[TLS_PAYLOAD_MAX_LEN];
871 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
872 	int chunk = TLS_PAYLOAD_MAX_LEN / 4;
873 	int n, i, p[2];
874 
875 	memrnd(mem_send, sizeof(mem_send));
876 
877 	ASSERT_GE(pipe(p), 0);
878 	for (i = 0; i < 4; i++)
879 		EXPECT_EQ(send(self->fd, &mem_send[chunk * i], chunk, 0),
880 			  chunk);
881 
882 	EXPECT_EQ(recv(self->cfd, mem_recv, chunk * 5 / 2,
883 		       MSG_WAITALL | MSG_PEEK),
884 		  chunk * 5 / 2);
885 	EXPECT_EQ(memcmp(mem_send, mem_recv, chunk * 5 / 2), 0);
886 
887 	n = 0;
888 	while (n < send_len) {
889 		i = splice(self->cfd, NULL, p[1], NULL, send_len - n, 0);
890 		EXPECT_GT(i, 0);
891 		n += i;
892 	}
893 	EXPECT_EQ(n, send_len);
894 	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
895 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
896 }
897 
898 TEST_F(tls, recvmsg_single)
899 {
900 	char const *test_str = "test_recvmsg_single";
901 	int send_len = strlen(test_str) + 1;
902 	char buf[20];
903 	struct msghdr hdr;
904 	struct iovec vec;
905 
906 	memset(&hdr, 0, sizeof(hdr));
907 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
908 	vec.iov_base = (char *)buf;
909 	vec.iov_len = send_len;
910 	hdr.msg_iovlen = 1;
911 	hdr.msg_iov = &vec;
912 	EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
913 	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
914 }
915 
916 TEST_F(tls, recvmsg_single_max)
917 {
918 	int send_len = TLS_PAYLOAD_MAX_LEN;
919 	char send_mem[TLS_PAYLOAD_MAX_LEN];
920 	char recv_mem[TLS_PAYLOAD_MAX_LEN];
921 	struct iovec vec;
922 	struct msghdr hdr;
923 
924 	memrnd(send_mem, sizeof(send_mem));
925 
926 	EXPECT_EQ(send(self->fd, send_mem, send_len, 0), send_len);
927 	vec.iov_base = (char *)recv_mem;
928 	vec.iov_len = TLS_PAYLOAD_MAX_LEN;
929 
930 	hdr.msg_iovlen = 1;
931 	hdr.msg_iov = &vec;
932 	EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
933 	EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
934 }
935 
936 TEST_F(tls, recvmsg_multiple)
937 {
938 	unsigned int msg_iovlen = 1024;
939 	struct iovec vec[1024];
940 	char *iov_base[1024];
941 	unsigned int iov_len = 16;
942 	int send_len = 1 << 14;
943 	char buf[1 << 14];
944 	struct msghdr hdr;
945 	int i;
946 
947 	memrnd(buf, sizeof(buf));
948 
949 	EXPECT_EQ(send(self->fd, buf, send_len, 0), send_len);
950 	for (i = 0; i < msg_iovlen; i++) {
951 		iov_base[i] = (char *)malloc(iov_len);
952 		vec[i].iov_base = iov_base[i];
953 		vec[i].iov_len = iov_len;
954 	}
955 
956 	hdr.msg_iovlen = msg_iovlen;
957 	hdr.msg_iov = vec;
958 	EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
959 
960 	for (i = 0; i < msg_iovlen; i++)
961 		free(iov_base[i]);
962 }
963 
964 TEST_F(tls, single_send_multiple_recv)
965 {
966 	unsigned int total_len = TLS_PAYLOAD_MAX_LEN * 2;
967 	unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
968 	char send_mem[TLS_PAYLOAD_MAX_LEN * 2];
969 	char recv_mem[TLS_PAYLOAD_MAX_LEN * 2];
970 
971 	memrnd(send_mem, sizeof(send_mem));
972 
973 	EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
974 	memset(recv_mem, 0, total_len);
975 
976 	EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
977 	EXPECT_NE(recv(self->cfd, recv_mem + send_len, send_len, 0), -1);
978 	EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
979 }
980 
981 TEST_F(tls, multiple_send_single_recv)
982 {
983 	unsigned int total_len = 2 * 10;
984 	unsigned int send_len = 10;
985 	char recv_mem[2 * 10];
986 	char send_mem[10];
987 
988 	memrnd(send_mem, sizeof(send_mem));
989 
990 	EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
991 	EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
992 	memset(recv_mem, 0, total_len);
993 	EXPECT_EQ(recv(self->cfd, recv_mem, total_len, MSG_WAITALL), total_len);
994 
995 	EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
996 	EXPECT_EQ(memcmp(send_mem, recv_mem + send_len, send_len), 0);
997 }
998 
999 TEST_F(tls, single_send_multiple_recv_non_align)
1000 {
1001 	const unsigned int total_len = 15;
1002 	const unsigned int recv_len = 10;
1003 	char recv_mem[recv_len * 2];
1004 	char send_mem[total_len];
1005 
1006 	memrnd(send_mem, sizeof(send_mem));
1007 
1008 	EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
1009 	memset(recv_mem, 0, total_len);
1010 
1011 	EXPECT_EQ(recv(self->cfd, recv_mem, recv_len, 0), recv_len);
1012 	EXPECT_EQ(recv(self->cfd, recv_mem + recv_len, recv_len, 0), 5);
1013 	EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
1014 }
1015 
1016 TEST_F(tls, recv_partial)
1017 {
1018 	char const *test_str = "test_read_partial";
1019 	char const *test_str_first = "test_read";
1020 	char const *test_str_second = "_partial";
1021 	int send_len = strlen(test_str) + 1;
1022 	char recv_mem[18];
1023 
1024 	memset(recv_mem, 0, sizeof(recv_mem));
1025 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1026 	EXPECT_EQ(recv(self->cfd, recv_mem, strlen(test_str_first),
1027 		       MSG_WAITALL), strlen(test_str_first));
1028 	EXPECT_EQ(memcmp(test_str_first, recv_mem, strlen(test_str_first)), 0);
1029 	memset(recv_mem, 0, sizeof(recv_mem));
1030 	EXPECT_EQ(recv(self->cfd, recv_mem, strlen(test_str_second),
1031 		       MSG_WAITALL), strlen(test_str_second));
1032 	EXPECT_EQ(memcmp(test_str_second, recv_mem, strlen(test_str_second)),
1033 		  0);
1034 }
1035 
1036 TEST_F(tls, recv_nonblock)
1037 {
1038 	char buf[4096];
1039 	bool err;
1040 
1041 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
1042 	err = (errno == EAGAIN || errno == EWOULDBLOCK);
1043 	EXPECT_EQ(err, true);
1044 }
1045 
1046 TEST_F(tls, recv_peek)
1047 {
1048 	char const *test_str = "test_read_peek";
1049 	int send_len = strlen(test_str) + 1;
1050 	char buf[15];
1051 
1052 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1053 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_PEEK), send_len);
1054 	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1055 	memset(buf, 0, sizeof(buf));
1056 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1057 	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1058 }
1059 
1060 TEST_F(tls, recv_peek_multiple)
1061 {
1062 	char const *test_str = "test_read_peek";
1063 	int send_len = strlen(test_str) + 1;
1064 	unsigned int num_peeks = 100;
1065 	char buf[15];
1066 	int i;
1067 
1068 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1069 	for (i = 0; i < num_peeks; i++) {
1070 		EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
1071 		EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1072 		memset(buf, 0, sizeof(buf));
1073 	}
1074 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1075 	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1076 }
1077 
1078 TEST_F(tls, recv_peek_multiple_records)
1079 {
1080 	char const *test_str = "test_read_peek_mult_recs";
1081 	char const *test_str_first = "test_read_peek";
1082 	char const *test_str_second = "_mult_recs";
1083 	int len;
1084 	char buf[64];
1085 
1086 	len = strlen(test_str_first);
1087 	EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
1088 
1089 	len = strlen(test_str_second) + 1;
1090 	EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1091 
1092 	len = strlen(test_str_first);
1093 	memset(buf, 0, len);
1094 	EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
1095 
1096 	/* MSG_PEEK can only peek into the current record. */
1097 	len = strlen(test_str_first);
1098 	EXPECT_EQ(memcmp(test_str_first, buf, len), 0);
1099 
1100 	len = strlen(test_str) + 1;
1101 	memset(buf, 0, len);
1102 	EXPECT_EQ(recv(self->cfd, buf, len, MSG_WAITALL), len);
1103 
1104 	/* Non-MSG_PEEK will advance strparser (and therefore record)
1105 	 * however.
1106 	 */
1107 	len = strlen(test_str) + 1;
1108 	EXPECT_EQ(memcmp(test_str, buf, len), 0);
1109 
1110 	/* MSG_MORE will hold current record open, so later MSG_PEEK
1111 	 * will see everything.
1112 	 */
1113 	len = strlen(test_str_first);
1114 	EXPECT_EQ(send(self->fd, test_str_first, len, MSG_MORE), len);
1115 
1116 	len = strlen(test_str_second) + 1;
1117 	EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1118 
1119 	len = strlen(test_str) + 1;
1120 	memset(buf, 0, len);
1121 	EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
1122 
1123 	len = strlen(test_str) + 1;
1124 	EXPECT_EQ(memcmp(test_str, buf, len), 0);
1125 }
1126 
1127 TEST_F(tls, recv_peek_large_buf_mult_recs)
1128 {
1129 	char const *test_str = "test_read_peek_mult_recs";
1130 	char const *test_str_first = "test_read_peek";
1131 	char const *test_str_second = "_mult_recs";
1132 	int len;
1133 	char buf[64];
1134 
1135 	len = strlen(test_str_first);
1136 	EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
1137 
1138 	len = strlen(test_str_second) + 1;
1139 	EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1140 
1141 	len = strlen(test_str) + 1;
1142 	memset(buf, 0, len);
1143 	EXPECT_NE((len = recv(self->cfd, buf, len,
1144 			      MSG_PEEK | MSG_WAITALL)), -1);
1145 	len = strlen(test_str) + 1;
1146 	EXPECT_EQ(memcmp(test_str, buf, len), 0);
1147 }
1148 
1149 TEST_F(tls, recv_lowat)
1150 {
1151 	char send_mem[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
1152 	char recv_mem[20];
1153 	int lowat = 8;
1154 
1155 	EXPECT_EQ(send(self->fd, send_mem, 10, 0), 10);
1156 	EXPECT_EQ(send(self->fd, send_mem, 5, 0), 5);
1157 
1158 	memset(recv_mem, 0, 20);
1159 	EXPECT_EQ(setsockopt(self->cfd, SOL_SOCKET, SO_RCVLOWAT,
1160 			     &lowat, sizeof(lowat)), 0);
1161 	EXPECT_EQ(recv(self->cfd, recv_mem, 1, MSG_WAITALL), 1);
1162 	EXPECT_EQ(recv(self->cfd, recv_mem + 1, 6, MSG_WAITALL), 6);
1163 	EXPECT_EQ(recv(self->cfd, recv_mem + 7, 10, 0), 8);
1164 
1165 	EXPECT_EQ(memcmp(send_mem, recv_mem, 10), 0);
1166 	EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0);
1167 }
1168 
1169 TEST_F(tls, bidir)
1170 {
1171 	char const *test_str = "test_read";
1172 	int send_len = 10;
1173 	char buf[10];
1174 	int ret;
1175 
1176 	if (!self->notls) {
1177 		struct tls_crypto_info_keys tls12;
1178 
1179 		tls_crypto_info_init(variant->tls_version, variant->cipher_type,
1180 				     &tls12, 0);
1181 
1182 		ret = setsockopt(self->fd, SOL_TLS, TLS_RX, &tls12,
1183 				 tls12.len);
1184 		ASSERT_EQ(ret, 0);
1185 
1186 		ret = setsockopt(self->cfd, SOL_TLS, TLS_TX, &tls12,
1187 				 tls12.len);
1188 		ASSERT_EQ(ret, 0);
1189 	}
1190 
1191 	ASSERT_EQ(strlen(test_str) + 1, send_len);
1192 
1193 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1194 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1195 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1196 
1197 	memset(buf, 0, sizeof(buf));
1198 
1199 	EXPECT_EQ(send(self->cfd, test_str, send_len, 0), send_len);
1200 	EXPECT_NE(recv(self->fd, buf, send_len, 0), -1);
1201 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1202 };
1203 
1204 TEST_F(tls, pollin)
1205 {
1206 	char const *test_str = "test_poll";
1207 	struct pollfd fd = { 0, 0, 0 };
1208 	char buf[10];
1209 	int send_len = 10;
1210 
1211 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1212 	fd.fd = self->cfd;
1213 	fd.events = POLLIN;
1214 
1215 	EXPECT_EQ(poll(&fd, 1, 20), 1);
1216 	EXPECT_EQ(fd.revents & POLLIN, 1);
1217 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
1218 	/* Test timing out */
1219 	EXPECT_EQ(poll(&fd, 1, 20), 0);
1220 }
1221 
1222 TEST_F(tls, poll_wait)
1223 {
1224 	char const *test_str = "test_poll_wait";
1225 	int send_len = strlen(test_str) + 1;
1226 	struct pollfd fd = { 0, 0, 0 };
1227 	char recv_mem[15];
1228 
1229 	fd.fd = self->cfd;
1230 	fd.events = POLLIN;
1231 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1232 	/* Set timeout to inf. secs */
1233 	EXPECT_EQ(poll(&fd, 1, -1), 1);
1234 	EXPECT_EQ(fd.revents & POLLIN, 1);
1235 	EXPECT_EQ(recv(self->cfd, recv_mem, send_len, MSG_WAITALL), send_len);
1236 }
1237 
1238 TEST_F(tls, poll_wait_split)
1239 {
1240 	struct pollfd fd = { 0, 0, 0 };
1241 	char send_mem[20] = {};
1242 	char recv_mem[15];
1243 
1244 	fd.fd = self->cfd;
1245 	fd.events = POLLIN;
1246 	/* Send 20 bytes */
1247 	EXPECT_EQ(send(self->fd, send_mem, sizeof(send_mem), 0),
1248 		  sizeof(send_mem));
1249 	/* Poll with inf. timeout */
1250 	EXPECT_EQ(poll(&fd, 1, -1), 1);
1251 	EXPECT_EQ(fd.revents & POLLIN, 1);
1252 	EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), MSG_WAITALL),
1253 		  sizeof(recv_mem));
1254 
1255 	/* Now the remaining 5 bytes of record data are in TLS ULP */
1256 	fd.fd = self->cfd;
1257 	fd.events = POLLIN;
1258 	EXPECT_EQ(poll(&fd, 1, -1), 1);
1259 	EXPECT_EQ(fd.revents & POLLIN, 1);
1260 	EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0),
1261 		  sizeof(send_mem) - sizeof(recv_mem));
1262 }
1263 
1264 TEST_F(tls, blocking)
1265 {
1266 	size_t data = 100000;
1267 	int res = fork();
1268 
1269 	EXPECT_NE(res, -1);
1270 
1271 	if (res) {
1272 		/* parent */
1273 		size_t left = data;
1274 		char buf[16384];
1275 		int status;
1276 		int pid2;
1277 
1278 		while (left) {
1279 			int res = send(self->fd, buf,
1280 				       left > 16384 ? 16384 : left, 0);
1281 
1282 			EXPECT_GE(res, 0);
1283 			left -= res;
1284 		}
1285 
1286 		pid2 = wait(&status);
1287 		EXPECT_EQ(status, 0);
1288 		EXPECT_EQ(res, pid2);
1289 	} else {
1290 		/* child */
1291 		size_t left = data;
1292 		char buf[16384];
1293 
1294 		while (left) {
1295 			int res = recv(self->cfd, buf,
1296 				       left > 16384 ? 16384 : left, 0);
1297 
1298 			EXPECT_GE(res, 0);
1299 			left -= res;
1300 		}
1301 	}
1302 }
1303 
1304 TEST_F(tls, nonblocking)
1305 {
1306 	size_t data = 100000;
1307 	int sendbuf = 100;
1308 	int flags;
1309 	int res;
1310 
1311 	flags = fcntl(self->fd, F_GETFL, 0);
1312 	fcntl(self->fd, F_SETFL, flags | O_NONBLOCK);
1313 	fcntl(self->cfd, F_SETFL, flags | O_NONBLOCK);
1314 
1315 	/* Ensure nonblocking behavior by imposing a small send
1316 	 * buffer.
1317 	 */
1318 	EXPECT_EQ(setsockopt(self->fd, SOL_SOCKET, SO_SNDBUF,
1319 			     &sendbuf, sizeof(sendbuf)), 0);
1320 
1321 	res = fork();
1322 	EXPECT_NE(res, -1);
1323 
1324 	if (res) {
1325 		/* parent */
1326 		bool eagain = false;
1327 		size_t left = data;
1328 		char buf[16384];
1329 		int status;
1330 		int pid2;
1331 
1332 		while (left) {
1333 			int res = send(self->fd, buf,
1334 				       left > 16384 ? 16384 : left, 0);
1335 
1336 			if (res == -1 && errno == EAGAIN) {
1337 				eagain = true;
1338 				usleep(10000);
1339 				continue;
1340 			}
1341 			EXPECT_GE(res, 0);
1342 			left -= res;
1343 		}
1344 
1345 		EXPECT_TRUE(eagain);
1346 		pid2 = wait(&status);
1347 
1348 		EXPECT_EQ(status, 0);
1349 		EXPECT_EQ(res, pid2);
1350 	} else {
1351 		/* child */
1352 		bool eagain = false;
1353 		size_t left = data;
1354 		char buf[16384];
1355 
1356 		while (left) {
1357 			int res = recv(self->cfd, buf,
1358 				       left > 16384 ? 16384 : left, 0);
1359 
1360 			if (res == -1 && errno == EAGAIN) {
1361 				eagain = true;
1362 				usleep(10000);
1363 				continue;
1364 			}
1365 			EXPECT_GE(res, 0);
1366 			left -= res;
1367 		}
1368 		EXPECT_TRUE(eagain);
1369 	}
1370 }
1371 
1372 static void
1373 test_mutliproc(struct __test_metadata *_metadata, struct _test_data_tls *self,
1374 	       bool sendpg, unsigned int n_readers, unsigned int n_writers)
1375 {
1376 	const unsigned int n_children = n_readers + n_writers;
1377 	const size_t data = 6 * 1000 * 1000;
1378 	const size_t file_sz = data / 100;
1379 	size_t read_bias, write_bias;
1380 	int i, fd, child_id;
1381 	char buf[file_sz];
1382 	pid_t pid;
1383 
1384 	/* Only allow multiples for simplicity */
1385 	ASSERT_EQ(!(n_readers % n_writers) || !(n_writers % n_readers), true);
1386 	read_bias = n_writers / n_readers ?: 1;
1387 	write_bias = n_readers / n_writers ?: 1;
1388 
1389 	/* prep a file to send */
1390 	fd = open("/tmp/", O_TMPFILE | O_RDWR, 0600);
1391 	ASSERT_GE(fd, 0);
1392 
1393 	memset(buf, 0xac, file_sz);
1394 	ASSERT_EQ(write(fd, buf, file_sz), file_sz);
1395 
1396 	/* spawn children */
1397 	for (child_id = 0; child_id < n_children; child_id++) {
1398 		pid = fork();
1399 		ASSERT_NE(pid, -1);
1400 		if (!pid)
1401 			break;
1402 	}
1403 
1404 	/* parent waits for all children */
1405 	if (pid) {
1406 		for (i = 0; i < n_children; i++) {
1407 			int status;
1408 
1409 			wait(&status);
1410 			EXPECT_EQ(status, 0);
1411 		}
1412 
1413 		return;
1414 	}
1415 
1416 	/* Split threads for reading and writing */
1417 	if (child_id < n_readers) {
1418 		size_t left = data * read_bias;
1419 		char rb[8001];
1420 
1421 		while (left) {
1422 			int res;
1423 
1424 			res = recv(self->cfd, rb,
1425 				   left > sizeof(rb) ? sizeof(rb) : left, 0);
1426 
1427 			EXPECT_GE(res, 0);
1428 			left -= res;
1429 		}
1430 	} else {
1431 		size_t left = data * write_bias;
1432 
1433 		while (left) {
1434 			int res;
1435 
1436 			ASSERT_EQ(lseek(fd, 0, SEEK_SET), 0);
1437 			if (sendpg)
1438 				res = sendfile(self->fd, fd, NULL,
1439 					       left > file_sz ? file_sz : left);
1440 			else
1441 				res = send(self->fd, buf,
1442 					   left > file_sz ? file_sz : left, 0);
1443 
1444 			EXPECT_GE(res, 0);
1445 			left -= res;
1446 		}
1447 	}
1448 }
1449 
/*
 * test_mutliproc() variants; the arguments are (sendpg, n_readers,
 * n_writers). The "sendpage" variants pass sendpg = true, so their writers
 * use sendfile() instead of send().
 */
/* 6 readers, 6 writers, send() */
TEST_F(tls, mutliproc_even)
{
	test_mutliproc(_metadata, self, false, 6, 6);
}

/* 4 readers, 12 writers, send() */
TEST_F(tls, mutliproc_readers)
{
	test_mutliproc(_metadata, self, false, 4, 12);
}

/* 10 readers, 2 writers, send() */
TEST_F(tls, mutliproc_writers)
{
	test_mutliproc(_metadata, self, false, 10, 2);
}

/* 6 readers, 6 writers, sendfile() */
TEST_F(tls, mutliproc_sendpage_even)
{
	test_mutliproc(_metadata, self, true, 6, 6);
}

/* 4 readers, 12 writers, sendfile() */
TEST_F(tls, mutliproc_sendpage_readers)
{
	test_mutliproc(_metadata, self, true, 4, 12);
}

/* 10 readers, 2 writers, sendfile() */
TEST_F(tls, mutliproc_sendpage_writers)
{
	test_mutliproc(_metadata, self, true, 10, 2);
}
1479 
1480 TEST_F(tls, control_msg)
1481 {
1482 	char *test_str = "test_read";
1483 	char record_type = 100;
1484 	int send_len = 10;
1485 	char buf[10];
1486 
1487 	if (self->notls)
1488 		SKIP(return, "no TLS support");
1489 
1490 	EXPECT_EQ(tls_send_cmsg(self->fd, record_type, test_str, send_len, 0),
1491 		  send_len);
1492 	/* Should fail because we didn't provide a control message */
1493 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
1494 
1495 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1496 				buf, sizeof(buf), MSG_WAITALL | MSG_PEEK),
1497 		  send_len);
1498 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1499 
1500 	/* Recv the message again without MSG_PEEK */
1501 	memset(buf, 0, sizeof(buf));
1502 
1503 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1504 				buf, sizeof(buf), MSG_WAITALL),
1505 		  send_len);
1506 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1507 }
1508 
1509 TEST_F(tls, control_msg_nomerge)
1510 {
1511 	char *rec1 = "1111";
1512 	char *rec2 = "2222";
1513 	int send_len = 5;
1514 	char buf[15];
1515 
1516 	if (self->notls)
1517 		SKIP(return, "no TLS support");
1518 
1519 	EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec1, send_len, 0), send_len);
1520 	EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec2, send_len, 0), send_len);
1521 
1522 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), MSG_PEEK), send_len);
1523 	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1524 
1525 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), MSG_PEEK), send_len);
1526 	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1527 
1528 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), 0), send_len);
1529 	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1530 
1531 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), 0), send_len);
1532 	EXPECT_EQ(memcmp(buf, rec2, send_len), 0);
1533 }
1534 
1535 TEST_F(tls, data_control_data)
1536 {
1537 	char *rec1 = "1111";
1538 	char *rec2 = "2222";
1539 	char *rec3 = "3333";
1540 	int send_len = 5;
1541 	char buf[15];
1542 
1543 	if (self->notls)
1544 		SKIP(return, "no TLS support");
1545 
1546 	EXPECT_EQ(send(self->fd, rec1, send_len, 0), send_len);
1547 	EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec2, send_len, 0), send_len);
1548 	EXPECT_EQ(send(self->fd, rec3, send_len, 0), send_len);
1549 
1550 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
1551 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
1552 }
1553 
1554 TEST_F(tls, shutdown)
1555 {
1556 	char const *test_str = "test_read";
1557 	int send_len = 10;
1558 	char buf[10];
1559 
1560 	ASSERT_EQ(strlen(test_str) + 1, send_len);
1561 
1562 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1563 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1564 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1565 
1566 	shutdown(self->fd, SHUT_RDWR);
1567 	shutdown(self->cfd, SHUT_RDWR);
1568 }
1569 
1570 TEST_F(tls, shutdown_unsent)
1571 {
1572 	char const *test_str = "test_read";
1573 	int send_len = 10;
1574 
1575 	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
1576 
1577 	shutdown(self->fd, SHUT_RDWR);
1578 	shutdown(self->cfd, SHUT_RDWR);
1579 }
1580 
1581 TEST_F(tls, shutdown_reuse)
1582 {
1583 	struct sockaddr_in addr;
1584 	int ret;
1585 
1586 	shutdown(self->fd, SHUT_RDWR);
1587 	shutdown(self->cfd, SHUT_RDWR);
1588 	close(self->cfd);
1589 
1590 	addr.sin_family = AF_INET;
1591 	addr.sin_addr.s_addr = htonl(INADDR_ANY);
1592 	addr.sin_port = 0;
1593 
1594 	ret = bind(self->fd, &addr, sizeof(addr));
1595 	EXPECT_EQ(ret, 0);
1596 	ret = listen(self->fd, 10);
1597 	EXPECT_EQ(ret, -1);
1598 	EXPECT_EQ(errno, EINVAL);
1599 
1600 	ret = connect(self->fd, &addr, sizeof(addr));
1601 	EXPECT_EQ(ret, -1);
1602 	EXPECT_EQ(errno, EISCONN);
1603 }
1604 
1605 TEST_F(tls, getsockopt)
1606 {
1607 	struct tls_crypto_info_keys expect, get;
1608 	socklen_t len;
1609 
1610 	/* get only the version/cipher */
1611 	len = sizeof(struct tls_crypto_info);
1612 	memrnd(&get, sizeof(get));
1613 	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), 0);
1614 	EXPECT_EQ(len, sizeof(struct tls_crypto_info));
1615 	EXPECT_EQ(get.crypto_info.version, variant->tls_version);
1616 	EXPECT_EQ(get.crypto_info.cipher_type, variant->cipher_type);
1617 
1618 	/* get the full crypto_info */
1619 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &expect, 0);
1620 	len = expect.len;
1621 	memrnd(&get, sizeof(get));
1622 	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), 0);
1623 	EXPECT_EQ(len, expect.len);
1624 	EXPECT_EQ(get.crypto_info.version, variant->tls_version);
1625 	EXPECT_EQ(get.crypto_info.cipher_type, variant->cipher_type);
1626 	EXPECT_EQ(memcmp(&get, &expect, expect.len), 0);
1627 
1628 	/* short get should fail */
1629 	len = sizeof(struct tls_crypto_info) - 1;
1630 	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), -1);
1631 	EXPECT_EQ(errno, EINVAL);
1632 
1633 	/* partial get of the cipher data should fail */
1634 	len = expect.len - 1;
1635 	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), -1);
1636 	EXPECT_EQ(errno, EINVAL);
1637 }
1638 
1639 TEST_F(tls, recv_efault)
1640 {
1641 	char *rec1 = "1111111111";
1642 	char *rec2 = "2222222222";
1643 	struct msghdr hdr = {};
1644 	struct iovec iov[2];
1645 	char recv_mem[12];
1646 	int ret;
1647 
1648 	if (self->notls)
1649 		SKIP(return, "no TLS support");
1650 
1651 	EXPECT_EQ(send(self->fd, rec1, 10, 0), 10);
1652 	EXPECT_EQ(send(self->fd, rec2, 10, 0), 10);
1653 
1654 	iov[0].iov_base = recv_mem;
1655 	iov[0].iov_len = sizeof(recv_mem);
1656 	iov[1].iov_base = NULL; /* broken iov to make process_rx_list fail */
1657 	iov[1].iov_len = 1;
1658 
1659 	hdr.msg_iovlen = 2;
1660 	hdr.msg_iov = iov;
1661 
1662 	EXPECT_EQ(recv(self->cfd, recv_mem, 1, 0), 1);
1663 	EXPECT_EQ(recv_mem[0], rec1[0]);
1664 
1665 	ret = recvmsg(self->cfd, &hdr, 0);
1666 	EXPECT_LE(ret, sizeof(recv_mem));
1667 	EXPECT_GE(ret, 9);
1668 	EXPECT_EQ(memcmp(rec1, recv_mem, 9), 0);
1669 	if (ret > 9)
1670 		EXPECT_EQ(memcmp(rec2, recv_mem + 9, ret - 9), 0);
1671 }
1672 
#define TLS_RECORD_TYPE_HANDSHAKE      0x16
/*
 * TLS 1.3 KeyUpdate handshake message:
 *   0x18            msg_type = key_update
 *   0x00 0x00 0x01  24-bit body length = 1
 *   0x00            request_update = update_not_requested
 */
static const char key_update_msg[] = "\x18\x00\x00\x01\x00";
/* Wire size: excludes the NUL terminator added by the string literal. */
#define KEY_UPDATE_MSG_LEN	(sizeof(key_update_msg) - 1)

/* Send key_update_msg on @fd as a single handshake-type record. */
static void tls_send_keyupdate(struct __test_metadata *_metadata, int fd)
{
	size_t len = KEY_UPDATE_MSG_LEN;

	EXPECT_EQ(tls_send_cmsg(fd, TLS_RECORD_TYPE_HANDSHAKE,
				(char *)key_update_msg, len, 0),
		  len);
}

/*
 * Receive (or peek, when @flags has MSG_PEEK) one handshake record on @fd
 * and check that it is exactly key_update_msg.
 */
static void tls_recv_keyupdate(struct __test_metadata *_metadata, int fd, int flags)
{
	char buf[100];

	EXPECT_EQ(tls_recv_cmsg(_metadata, fd, TLS_RECORD_TYPE_HANDSHAKE, buf, sizeof(buf), flags),
		  KEY_UPDATE_MSG_LEN);
	EXPECT_EQ(memcmp(buf, key_update_msg, KEY_UPDATE_MSG_LEN), 0);
}
1693 
1694 /* set the key to 0 then 1 for RX, immediately to 1 for TX */
1695 TEST_F(tls_basic, rekey_rx)
1696 {
1697 	struct tls_crypto_info_keys tls12_0, tls12_1;
1698 	char const *test_str = "test_message";
1699 	int send_len = strlen(test_str) + 1;
1700 	char buf[20];
1701 	int ret;
1702 
1703 	if (self->notls)
1704 		return;
1705 
1706 	tls_crypto_info_init(TLS_1_3_VERSION, TLS_CIPHER_AES_GCM_128,
1707 			     &tls12_0, 0);
1708 	tls_crypto_info_init(TLS_1_3_VERSION, TLS_CIPHER_AES_GCM_128,
1709 			     &tls12_1, 1);
1710 
1711 	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_1, tls12_1.len);
1712 	ASSERT_EQ(ret, 0);
1713 
1714 	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_0, tls12_0.len);
1715 	ASSERT_EQ(ret, 0);
1716 
1717 	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_1, tls12_1.len);
1718 	EXPECT_EQ(ret, 0);
1719 
1720 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1721 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1722 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1723 }
1724 
1725 /* set the key to 0 then 1 for TX, immediately to 1 for RX */
1726 TEST_F(tls_basic, rekey_tx)
1727 {
1728 	struct tls_crypto_info_keys tls12_0, tls12_1;
1729 	char const *test_str = "test_message";
1730 	int send_len = strlen(test_str) + 1;
1731 	char buf[20];
1732 	int ret;
1733 
1734 	if (self->notls)
1735 		return;
1736 
1737 	tls_crypto_info_init(TLS_1_3_VERSION, TLS_CIPHER_AES_GCM_128,
1738 			     &tls12_0, 0);
1739 	tls_crypto_info_init(TLS_1_3_VERSION, TLS_CIPHER_AES_GCM_128,
1740 			     &tls12_1, 1);
1741 
1742 	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_0, tls12_0.len);
1743 	ASSERT_EQ(ret, 0);
1744 
1745 	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_1, tls12_1.len);
1746 	ASSERT_EQ(ret, 0);
1747 
1748 	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_1, tls12_1.len);
1749 	EXPECT_EQ(ret, 0);
1750 
1751 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1752 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1753 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1754 }
1755 
1756 TEST_F(tls, rekey)
1757 {
1758 	char const *test_str_1 = "test_message_before_rekey";
1759 	char const *test_str_2 = "test_message_after_rekey";
1760 	struct tls_crypto_info_keys tls12;
1761 	int send_len;
1762 	char buf[100];
1763 
1764 	if (variant->tls_version != TLS_1_3_VERSION)
1765 		return;
1766 
1767 	/* initial send/recv */
1768 	send_len = strlen(test_str_1) + 1;
1769 	EXPECT_EQ(send(self->fd, test_str_1, send_len, 0), send_len);
1770 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1771 	EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
1772 
1773 	/* update TX key */
1774 	tls_send_keyupdate(_metadata, self->fd);
1775 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1776 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
1777 
1778 	/* send after rekey */
1779 	send_len = strlen(test_str_2) + 1;
1780 	EXPECT_EQ(send(self->fd, test_str_2, send_len, 0), send_len);
1781 
1782 	/* can't receive the KeyUpdate without a control message */
1783 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
1784 
1785 	/* get KeyUpdate */
1786 	tls_recv_keyupdate(_metadata, self->cfd, 0);
1787 
1788 	/* recv blocking -> -EKEYEXPIRED */
1789 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), 0), -1);
1790 	EXPECT_EQ(errno, EKEYEXPIRED);
1791 
1792 	/* recv non-blocking -> -EKEYEXPIRED */
1793 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
1794 	EXPECT_EQ(errno, EKEYEXPIRED);
1795 
1796 	/* update RX key */
1797 	EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
1798 
1799 	/* recv after rekey */
1800 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1801 	EXPECT_EQ(memcmp(buf, test_str_2, send_len), 0);
1802 }
1803 
1804 TEST_F(tls, rekey_fail)
1805 {
1806 	char const *test_str_1 = "test_message_before_rekey";
1807 	char const *test_str_2 = "test_message_after_rekey";
1808 	struct tls_crypto_info_keys tls12;
1809 	int send_len;
1810 	char buf[100];
1811 
1812 	/* initial send/recv */
1813 	send_len = strlen(test_str_1) + 1;
1814 	EXPECT_EQ(send(self->fd, test_str_1, send_len, 0), send_len);
1815 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1816 	EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
1817 
1818 	/* update TX key */
1819 	tls_send_keyupdate(_metadata, self->fd);
1820 
1821 	if (variant->tls_version != TLS_1_3_VERSION) {
1822 		/* just check that rekey is not supported and return */
1823 		tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1824 		EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), -1);
1825 		EXPECT_EQ(errno, EBUSY);
1826 		return;
1827 	}
1828 
1829 	/* successful update */
1830 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1831 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
1832 
1833 	/* invalid update: change of version */
1834 	tls_crypto_info_init(TLS_1_2_VERSION, variant->cipher_type, &tls12, 1);
1835 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), -1);
1836 	EXPECT_EQ(errno, EINVAL);
1837 
1838 	/* invalid update (RX socket): change of version */
1839 	tls_crypto_info_init(TLS_1_2_VERSION, variant->cipher_type, &tls12, 1);
1840 	EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), -1);
1841 	EXPECT_EQ(errno, EINVAL);
1842 
1843 	/* invalid update: change of cipher */
1844 	if (variant->cipher_type == TLS_CIPHER_AES_GCM_256)
1845 		tls_crypto_info_init(variant->tls_version, TLS_CIPHER_CHACHA20_POLY1305, &tls12, 1);
1846 	else
1847 		tls_crypto_info_init(variant->tls_version, TLS_CIPHER_AES_GCM_256, &tls12, 1);
1848 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), -1);
1849 	EXPECT_EQ(errno, EINVAL);
1850 
1851 	/* send after rekey, the invalid updates shouldn't have an effect */
1852 	send_len = strlen(test_str_2) + 1;
1853 	EXPECT_EQ(send(self->fd, test_str_2, send_len, 0), send_len);
1854 
1855 	/* can't receive the KeyUpdate without a control message */
1856 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
1857 
1858 	/* get KeyUpdate */
1859 	tls_recv_keyupdate(_metadata, self->cfd, 0);
1860 
1861 	/* recv blocking -> -EKEYEXPIRED */
1862 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), 0), -1);
1863 	EXPECT_EQ(errno, EKEYEXPIRED);
1864 
1865 	/* recv non-blocking -> -EKEYEXPIRED */
1866 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
1867 	EXPECT_EQ(errno, EKEYEXPIRED);
1868 
1869 	/* update RX key */
1870 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1871 	EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
1872 
1873 	/* recv after rekey */
1874 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1875 	EXPECT_EQ(memcmp(buf, test_str_2, send_len), 0);
1876 }
1877 
1878 TEST_F(tls, rekey_peek)
1879 {
1880 	char const *test_str_1 = "test_message_before_rekey";
1881 	struct tls_crypto_info_keys tls12;
1882 	int send_len;
1883 	char buf[100];
1884 
1885 	if (variant->tls_version != TLS_1_3_VERSION)
1886 		return;
1887 
1888 	send_len = strlen(test_str_1) + 1;
1889 	EXPECT_EQ(send(self->fd, test_str_1, send_len, 0), send_len);
1890 
1891 	/* update TX key */
1892 	tls_send_keyupdate(_metadata, self->fd);
1893 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1894 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
1895 
1896 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
1897 	EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
1898 
1899 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1900 	EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
1901 
1902 	/* can't receive the KeyUpdate without a control message */
1903 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
1904 
1905 	/* peek KeyUpdate */
1906 	tls_recv_keyupdate(_metadata, self->cfd, MSG_PEEK);
1907 
1908 	/* get KeyUpdate */
1909 	tls_recv_keyupdate(_metadata, self->cfd, 0);
1910 
1911 	/* update RX key */
1912 	EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
1913 }
1914 
1915 TEST_F(tls, splice_rekey)
1916 {
1917 	int send_len = TLS_PAYLOAD_MAX_LEN / 2;
1918 	char mem_send[TLS_PAYLOAD_MAX_LEN];
1919 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
1920 	struct tls_crypto_info_keys tls12;
1921 	int p[2];
1922 
1923 	if (variant->tls_version != TLS_1_3_VERSION)
1924 		return;
1925 
1926 	memrnd(mem_send, sizeof(mem_send));
1927 
1928 	ASSERT_GE(pipe(p), 0);
1929 	EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
1930 
1931 	/* update TX key */
1932 	tls_send_keyupdate(_metadata, self->fd);
1933 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1934 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
1935 
1936 	EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
1937 
1938 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), send_len);
1939 	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
1940 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
1941 
1942 	/* can't splice the KeyUpdate */
1943 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), -1);
1944 	EXPECT_EQ(errno, EINVAL);
1945 
1946 	/* peek KeyUpdate */
1947 	tls_recv_keyupdate(_metadata, self->cfd, MSG_PEEK);
1948 
1949 	/* get KeyUpdate */
1950 	tls_recv_keyupdate(_metadata, self->cfd, 0);
1951 
1952 	/* can't splice before updating the key */
1953 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), -1);
1954 	EXPECT_EQ(errno, EKEYEXPIRED);
1955 
1956 	/* update RX key */
1957 	EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
1958 
1959 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), send_len);
1960 	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
1961 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
1962 }
1963 
1964 TEST_F(tls, rekey_peek_splice)
1965 {
1966 	char const *test_str_1 = "test_message_before_rekey";
1967 	struct tls_crypto_info_keys tls12;
1968 	int send_len;
1969 	char buf[100];
1970 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
1971 	int p[2];
1972 
1973 	if (variant->tls_version != TLS_1_3_VERSION)
1974 		return;
1975 
1976 	ASSERT_GE(pipe(p), 0);
1977 
1978 	send_len = strlen(test_str_1) + 1;
1979 	EXPECT_EQ(send(self->fd, test_str_1, send_len, 0), send_len);
1980 
1981 	/* update TX key */
1982 	tls_send_keyupdate(_metadata, self->fd);
1983 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1984 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
1985 
1986 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
1987 	EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
1988 
1989 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), send_len);
1990 	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
1991 	EXPECT_EQ(memcmp(mem_recv, test_str_1, send_len), 0);
1992 }
1993 
1994 TEST_F(tls, rekey_getsockopt)
1995 {
1996 	struct tls_crypto_info_keys tls12;
1997 	struct tls_crypto_info_keys tls12_get;
1998 	socklen_t len;
1999 
2000 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 0);
2001 
2002 	len = tls12.len;
2003 	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_get, &len), 0);
2004 	EXPECT_EQ(len, tls12.len);
2005 	EXPECT_EQ(memcmp(&tls12_get, &tls12, tls12.len), 0);
2006 
2007 	len = tls12.len;
2008 	EXPECT_EQ(getsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_get, &len), 0);
2009 	EXPECT_EQ(len, tls12.len);
2010 	EXPECT_EQ(memcmp(&tls12_get, &tls12, tls12.len), 0);
2011 
2012 	if (variant->tls_version != TLS_1_3_VERSION)
2013 		return;
2014 
2015 	tls_send_keyupdate(_metadata, self->fd);
2016 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
2017 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2018 
2019 	tls_recv_keyupdate(_metadata, self->cfd, 0);
2020 	EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2021 
2022 	len = tls12.len;
2023 	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_get, &len), 0);
2024 	EXPECT_EQ(len, tls12.len);
2025 	EXPECT_EQ(memcmp(&tls12_get, &tls12, tls12.len), 0);
2026 
2027 	len = tls12.len;
2028 	EXPECT_EQ(getsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_get, &len), 0);
2029 	EXPECT_EQ(len, tls12.len);
2030 	EXPECT_EQ(memcmp(&tls12_get, &tls12, tls12.len), 0);
2031 }
2032 
2033 TEST_F(tls, rekey_poll_pending)
2034 {
2035 	char const *test_str = "test_message_after_rekey";
2036 	struct tls_crypto_info_keys tls12;
2037 	struct pollfd pfd = { };
2038 	int send_len;
2039 	int ret;
2040 
2041 	if (variant->tls_version != TLS_1_3_VERSION)
2042 		return;
2043 
2044 	/* update TX key */
2045 	tls_send_keyupdate(_metadata, self->fd);
2046 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
2047 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2048 
2049 	/* get KeyUpdate */
2050 	tls_recv_keyupdate(_metadata, self->cfd, 0);
2051 
2052 	/* send immediately after rekey */
2053 	send_len = strlen(test_str) + 1;
2054 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
2055 
2056 	/* key hasn't been updated, expect cfd to be non-readable */
2057 	pfd.fd = self->cfd;
2058 	pfd.events = POLLIN;
2059 	EXPECT_EQ(poll(&pfd, 1, 0), 0);
2060 
2061 	ret = fork();
2062 	ASSERT_GE(ret, 0);
2063 
2064 	if (ret) {
2065 		int pid2, status;
2066 
2067 		/* wait before installing the new key */
2068 		sleep(1);
2069 
2070 		/* update RX key while poll() is sleeping */
2071 		EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2072 
2073 		pid2 = wait(&status);
2074 		EXPECT_EQ(pid2, ret);
2075 		EXPECT_EQ(status, 0);
2076 	} else {
2077 		pfd.fd = self->cfd;
2078 		pfd.events = POLLIN;
2079 		EXPECT_EQ(poll(&pfd, 1, 5000), 1);
2080 
2081 		exit(!__test_passed(_metadata));
2082 	}
2083 }
2084 
2085 TEST_F(tls, rekey_poll_delay)
2086 {
2087 	char const *test_str = "test_message_after_rekey";
2088 	struct tls_crypto_info_keys tls12;
2089 	struct pollfd pfd = { };
2090 	int send_len;
2091 	int ret;
2092 
2093 	if (variant->tls_version != TLS_1_3_VERSION)
2094 		return;
2095 
2096 	/* update TX key */
2097 	tls_send_keyupdate(_metadata, self->fd);
2098 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
2099 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2100 
2101 	/* get KeyUpdate */
2102 	tls_recv_keyupdate(_metadata, self->cfd, 0);
2103 
2104 	ret = fork();
2105 	ASSERT_GE(ret, 0);
2106 
2107 	if (ret) {
2108 		int pid2, status;
2109 
2110 		/* wait before installing the new key */
2111 		sleep(1);
2112 
2113 		/* update RX key while poll() is sleeping */
2114 		EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2115 
2116 		sleep(1);
2117 		send_len = strlen(test_str) + 1;
2118 		EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
2119 
2120 		pid2 = wait(&status);
2121 		EXPECT_EQ(pid2, ret);
2122 		EXPECT_EQ(status, 0);
2123 	} else {
2124 		pfd.fd = self->cfd;
2125 		pfd.events = POLLIN;
2126 		EXPECT_EQ(poll(&pfd, 1, 5000), 1);
2127 		exit(!__test_passed(_metadata));
2128 	}
2129 }
2130 
2131 FIXTURE(tls_err)
2132 {
2133 	int fd, cfd;
2134 	int fd2, cfd2;
2135 	bool notls;
2136 };
2137 
2138 FIXTURE_VARIANT(tls_err)
2139 {
2140 	uint16_t tls_version;
2141 };
2142 
2143 FIXTURE_VARIANT_ADD(tls_err, 12_aes_gcm)
2144 {
2145 	.tls_version = TLS_1_2_VERSION,
2146 };
2147 
2148 FIXTURE_VARIANT_ADD(tls_err, 13_aes_gcm)
2149 {
2150 	.tls_version = TLS_1_3_VERSION,
2151 };
2152 
2153 FIXTURE_SETUP(tls_err)
2154 {
2155 	struct tls_crypto_info_keys tls12;
2156 	int ret;
2157 
2158 	tls_crypto_info_init(variant->tls_version, TLS_CIPHER_AES_GCM_128,
2159 			     &tls12, 0);
2160 
2161 	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
2162 	ulp_sock_pair(_metadata, &self->fd2, &self->cfd2, &self->notls);
2163 	if (self->notls)
2164 		return;
2165 
2166 	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
2167 	ASSERT_EQ(ret, 0);
2168 
2169 	ret = setsockopt(self->cfd2, SOL_TLS, TLS_RX, &tls12, tls12.len);
2170 	ASSERT_EQ(ret, 0);
2171 }
2172 
2173 FIXTURE_TEARDOWN(tls_err)
2174 {
2175 	close(self->fd);
2176 	close(self->cfd);
2177 	close(self->fd2);
2178 	close(self->cfd2);
2179 }
2180 
2181 TEST_F(tls_err, bad_rec)
2182 {
2183 	char buf[64];
2184 
2185 	if (self->notls)
2186 		SKIP(return, "no TLS support");
2187 
2188 	memset(buf, 0x55, sizeof(buf));
2189 	EXPECT_EQ(send(self->fd2, buf, sizeof(buf), 0), sizeof(buf));
2190 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2191 	EXPECT_EQ(errno, EMSGSIZE);
2192 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), MSG_DONTWAIT), -1);
2193 	EXPECT_EQ(errno, EAGAIN);
2194 }
2195 
2196 TEST_F(tls_err, bad_auth)
2197 {
2198 	char buf[128];
2199 	int n;
2200 
2201 	if (self->notls)
2202 		SKIP(return, "no TLS support");
2203 
2204 	memrnd(buf, sizeof(buf) / 2);
2205 	EXPECT_EQ(send(self->fd, buf, sizeof(buf) / 2, 0), sizeof(buf) / 2);
2206 	n = recv(self->cfd, buf, sizeof(buf), 0);
2207 	EXPECT_GT(n, sizeof(buf) / 2);
2208 
2209 	buf[n - 1]++;
2210 
2211 	EXPECT_EQ(send(self->fd2, buf, n, 0), n);
2212 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2213 	EXPECT_EQ(errno, EBADMSG);
2214 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2215 	EXPECT_EQ(errno, EBADMSG);
2216 }
2217 
2218 TEST_F(tls_err, bad_in_large_read)
2219 {
2220 	char txt[3][64];
2221 	char cip[3][128];
2222 	char buf[3 * 128];
2223 	int i, n;
2224 
2225 	if (self->notls)
2226 		SKIP(return, "no TLS support");
2227 
2228 	/* Put 3 records in the sockets */
2229 	for (i = 0; i < 3; i++) {
2230 		memrnd(txt[i], sizeof(txt[i]));
2231 		EXPECT_EQ(send(self->fd, txt[i], sizeof(txt[i]), 0),
2232 			  sizeof(txt[i]));
2233 		n = recv(self->cfd, cip[i], sizeof(cip[i]), 0);
2234 		EXPECT_GT(n, sizeof(txt[i]));
2235 		/* Break the third message */
2236 		if (i == 2)
2237 			cip[2][n - 1]++;
2238 		EXPECT_EQ(send(self->fd2, cip[i], n, 0), n);
2239 	}
2240 
2241 	/* We should be able to receive the first two messages */
2242 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt[0]) * 2);
2243 	EXPECT_EQ(memcmp(buf, txt[0], sizeof(txt[0])), 0);
2244 	EXPECT_EQ(memcmp(buf + sizeof(txt[0]), txt[1], sizeof(txt[1])), 0);
2245 	/* Third mesasge is bad */
2246 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2247 	EXPECT_EQ(errno, EBADMSG);
2248 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2249 	EXPECT_EQ(errno, EBADMSG);
2250 }
2251 
2252 TEST_F(tls_err, bad_cmsg)
2253 {
2254 	char *test_str = "test_read";
2255 	int send_len = 10;
2256 	char cip[128];
2257 	char buf[128];
2258 	char txt[64];
2259 	int n;
2260 
2261 	if (self->notls)
2262 		SKIP(return, "no TLS support");
2263 
2264 	/* Queue up one data record */
2265 	memrnd(txt, sizeof(txt));
2266 	EXPECT_EQ(send(self->fd, txt, sizeof(txt), 0), sizeof(txt));
2267 	n = recv(self->cfd, cip, sizeof(cip), 0);
2268 	EXPECT_GT(n, sizeof(txt));
2269 	EXPECT_EQ(send(self->fd2, cip, n, 0), n);
2270 
2271 	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
2272 	n = recv(self->cfd, cip, sizeof(cip), 0);
2273 	cip[n - 1]++; /* Break it */
2274 	EXPECT_GT(n, send_len);
2275 	EXPECT_EQ(send(self->fd2, cip, n, 0), n);
2276 
2277 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt));
2278 	EXPECT_EQ(memcmp(buf, txt, sizeof(txt)), 0);
2279 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2280 	EXPECT_EQ(errno, EBADMSG);
2281 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2282 	EXPECT_EQ(errno, EBADMSG);
2283 }
2284 
2285 TEST_F(tls_err, timeo)
2286 {
2287 	struct timeval tv = { .tv_usec = 10000, };
2288 	char buf[128];
2289 	int ret;
2290 
2291 	if (self->notls)
2292 		SKIP(return, "no TLS support");
2293 
2294 	ret = setsockopt(self->cfd2, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
2295 	ASSERT_EQ(ret, 0);
2296 
2297 	ret = fork();
2298 	ASSERT_GE(ret, 0);
2299 
2300 	if (ret) {
2301 		usleep(1000); /* Give child a head start */
2302 
2303 		EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2304 		EXPECT_EQ(errno, EAGAIN);
2305 
2306 		EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2307 		EXPECT_EQ(errno, EAGAIN);
2308 
2309 		wait(&ret);
2310 	} else {
2311 		EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2312 		EXPECT_EQ(errno, EAGAIN);
2313 		exit(0);
2314 	}
2315 }
2316 
2317 TEST_F(tls_err, poll_partial_rec)
2318 {
2319 	struct pollfd pfd = { };
2320 	ssize_t rec_len;
2321 	char rec[256];
2322 	char buf[128];
2323 
2324 	if (self->notls)
2325 		SKIP(return, "no TLS support");
2326 
2327 	pfd.fd = self->cfd2;
2328 	pfd.events = POLLIN;
2329 	EXPECT_EQ(poll(&pfd, 1, 1), 0);
2330 
2331 	memrnd(buf, sizeof(buf));
2332 	EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
2333 	rec_len = recv(self->cfd, rec, sizeof(rec), 0);
2334 	EXPECT_GT(rec_len, sizeof(buf));
2335 
2336 	/* Write 100B, not the full record ... */
2337 	EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
2338 	/* ... no full record should mean no POLLIN */
2339 	pfd.fd = self->cfd2;
2340 	pfd.events = POLLIN;
2341 	EXPECT_EQ(poll(&pfd, 1, 1), 0);
2342 	/* Now write the rest, and it should all pop out of the other end. */
2343 	EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0), rec_len - 100);
2344 	pfd.fd = self->cfd2;
2345 	pfd.events = POLLIN;
2346 	EXPECT_EQ(poll(&pfd, 1, 1), 1);
2347 	EXPECT_EQ(recv(self->cfd2, rec, sizeof(rec), 0), sizeof(buf));
2348 	EXPECT_EQ(memcmp(buf, rec, sizeof(buf)), 0);
2349 }
2350 
2351 TEST_F(tls_err, epoll_partial_rec)
2352 {
2353 	struct epoll_event ev, events[10];
2354 	ssize_t rec_len;
2355 	char rec[256];
2356 	char buf[128];
2357 	int epollfd;
2358 
2359 	if (self->notls)
2360 		SKIP(return, "no TLS support");
2361 
2362 	epollfd = epoll_create1(0);
2363 	ASSERT_GE(epollfd, 0);
2364 
2365 	memset(&ev, 0, sizeof(ev));
2366 	ev.events = EPOLLIN;
2367 	ev.data.fd = self->cfd2;
2368 	ASSERT_GE(epoll_ctl(epollfd, EPOLL_CTL_ADD, self->cfd2, &ev), 0);
2369 
2370 	EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 0);
2371 
2372 	memrnd(buf, sizeof(buf));
2373 	EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
2374 	rec_len = recv(self->cfd, rec, sizeof(rec), 0);
2375 	EXPECT_GT(rec_len, sizeof(buf));
2376 
2377 	/* Write 100B, not the full record ... */
2378 	EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
2379 	/* ... no full record should mean no POLLIN */
2380 	EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 0);
2381 	/* Now write the rest, and it should all pop out of the other end. */
2382 	EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0), rec_len - 100);
2383 	EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 1);
2384 	EXPECT_EQ(recv(self->cfd2, rec, sizeof(rec), 0), sizeof(buf));
2385 	EXPECT_EQ(memcmp(buf, rec, sizeof(buf)), 0);
2386 
2387 	close(epollfd);
2388 }
2389 
2390 TEST_F(tls_err, poll_partial_rec_async)
2391 {
2392 	struct pollfd pfd = { };
2393 	ssize_t rec_len;
2394 	char rec[256];
2395 	char buf[128];
2396 	char token;
2397 	int p[2];
2398 	int ret;
2399 
2400 	if (self->notls)
2401 		SKIP(return, "no TLS support");
2402 
2403 	ASSERT_GE(pipe(p), 0);
2404 
2405 	memrnd(buf, sizeof(buf));
2406 	EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
2407 	rec_len = recv(self->cfd, rec, sizeof(rec), 0);
2408 	EXPECT_GT(rec_len, sizeof(buf));
2409 
2410 	ret = fork();
2411 	ASSERT_GE(ret, 0);
2412 
2413 	if (ret) {
2414 		int status, pid2;
2415 
2416 		close(p[1]);
2417 		usleep(1000); /* Give child a head start */
2418 
2419 		EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
2420 
2421 		EXPECT_EQ(read(p[0], &token, 1), 1); /* Barrier #1 */
2422 
2423 		EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0),
2424 			  rec_len - 100);
2425 
2426 		pid2 = wait(&status);
2427 		EXPECT_EQ(pid2, ret);
2428 		EXPECT_EQ(status, 0);
2429 	} else {
2430 		close(p[0]);
2431 
2432 		/* Child should sleep in poll(), never get a wake */
2433 		pfd.fd = self->cfd2;
2434 		pfd.events = POLLIN;
2435 		EXPECT_EQ(poll(&pfd, 1, 20), 0);
2436 
2437 		EXPECT_EQ(write(p[1], &token, 1), 1); /* Barrier #1 */
2438 
2439 		pfd.fd = self->cfd2;
2440 		pfd.events = POLLIN;
2441 		EXPECT_EQ(poll(&pfd, 1, 20), 1);
2442 
2443 		exit(!__test_passed(_metadata));
2444 	}
2445 }
2446 
2447 TEST(non_established) {
2448 	struct tls12_crypto_info_aes_gcm_256 tls12;
2449 	struct sockaddr_in addr;
2450 	int sfd, ret, fd;
2451 	socklen_t len;
2452 
2453 	len = sizeof(addr);
2454 
2455 	memset(&tls12, 0, sizeof(tls12));
2456 	tls12.info.version = TLS_1_2_VERSION;
2457 	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
2458 
2459 	addr.sin_family = AF_INET;
2460 	addr.sin_addr.s_addr = htonl(INADDR_ANY);
2461 	addr.sin_port = 0;
2462 
2463 	fd = socket(AF_INET, SOCK_STREAM, 0);
2464 	sfd = socket(AF_INET, SOCK_STREAM, 0);
2465 
2466 	ret = bind(sfd, &addr, sizeof(addr));
2467 	ASSERT_EQ(ret, 0);
2468 	ret = listen(sfd, 10);
2469 	ASSERT_EQ(ret, 0);
2470 
2471 	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2472 	EXPECT_EQ(ret, -1);
2473 	/* TLS ULP not supported */
2474 	if (errno == ENOENT)
2475 		return;
2476 	EXPECT_EQ(errno, ENOTCONN);
2477 
2478 	ret = setsockopt(sfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2479 	EXPECT_EQ(ret, -1);
2480 	EXPECT_EQ(errno, ENOTCONN);
2481 
2482 	ret = getsockname(sfd, &addr, &len);
2483 	ASSERT_EQ(ret, 0);
2484 
2485 	ret = connect(fd, &addr, sizeof(addr));
2486 	ASSERT_EQ(ret, 0);
2487 
2488 	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2489 	ASSERT_EQ(ret, 0);
2490 
2491 	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2492 	EXPECT_EQ(ret, -1);
2493 	EXPECT_EQ(errno, EEXIST);
2494 
2495 	close(fd);
2496 	close(sfd);
2497 }
2498 
2499 TEST(keysizes) {
2500 	struct tls12_crypto_info_aes_gcm_256 tls12;
2501 	int ret, fd, cfd;
2502 	bool notls;
2503 
2504 	memset(&tls12, 0, sizeof(tls12));
2505 	tls12.info.version = TLS_1_2_VERSION;
2506 	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
2507 
2508 	ulp_sock_pair(_metadata, &fd, &cfd, &notls);
2509 
2510 	if (!notls) {
2511 		ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
2512 				 sizeof(tls12));
2513 		EXPECT_EQ(ret, 0);
2514 
2515 		ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
2516 				 sizeof(tls12));
2517 		EXPECT_EQ(ret, 0);
2518 	}
2519 
2520 	close(fd);
2521 	close(cfd);
2522 }
2523 
2524 TEST(no_pad) {
2525 	struct tls12_crypto_info_aes_gcm_256 tls12;
2526 	int ret, fd, cfd, val;
2527 	socklen_t len;
2528 	bool notls;
2529 
2530 	memset(&tls12, 0, sizeof(tls12));
2531 	tls12.info.version = TLS_1_3_VERSION;
2532 	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
2533 
2534 	ulp_sock_pair(_metadata, &fd, &cfd, &notls);
2535 
2536 	if (notls)
2537 		exit(KSFT_SKIP);
2538 
2539 	ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12, sizeof(tls12));
2540 	EXPECT_EQ(ret, 0);
2541 
2542 	ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12, sizeof(tls12));
2543 	EXPECT_EQ(ret, 0);
2544 
2545 	val = 1;
2546 	ret = setsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2547 			 (void *)&val, sizeof(val));
2548 	EXPECT_EQ(ret, 0);
2549 
2550 	len = sizeof(val);
2551 	val = 2;
2552 	ret = getsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2553 			 (void *)&val, &len);
2554 	EXPECT_EQ(ret, 0);
2555 	EXPECT_EQ(val, 1);
2556 	EXPECT_EQ(len, 4);
2557 
2558 	val = 0;
2559 	ret = setsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2560 			 (void *)&val, sizeof(val));
2561 	EXPECT_EQ(ret, 0);
2562 
2563 	len = sizeof(val);
2564 	val = 2;
2565 	ret = getsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2566 			 (void *)&val, &len);
2567 	EXPECT_EQ(ret, 0);
2568 	EXPECT_EQ(val, 0);
2569 	EXPECT_EQ(len, 4);
2570 
2571 	close(fd);
2572 	close(cfd);
2573 }
2574 
2575 TEST(tls_v6ops) {
2576 	struct tls_crypto_info_keys tls12;
2577 	struct sockaddr_in6 addr, addr2;
2578 	int sfd, ret, fd;
2579 	socklen_t len, len2;
2580 
2581 	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_128, &tls12, 0);
2582 
2583 	addr.sin6_family = AF_INET6;
2584 	addr.sin6_addr = in6addr_any;
2585 	addr.sin6_port = 0;
2586 
2587 	fd = socket(AF_INET6, SOCK_STREAM, 0);
2588 	sfd = socket(AF_INET6, SOCK_STREAM, 0);
2589 
2590 	ret = bind(sfd, &addr, sizeof(addr));
2591 	ASSERT_EQ(ret, 0);
2592 	ret = listen(sfd, 10);
2593 	ASSERT_EQ(ret, 0);
2594 
2595 	len = sizeof(addr);
2596 	ret = getsockname(sfd, &addr, &len);
2597 	ASSERT_EQ(ret, 0);
2598 
2599 	ret = connect(fd, &addr, sizeof(addr));
2600 	ASSERT_EQ(ret, 0);
2601 
2602 	len = sizeof(addr);
2603 	ret = getsockname(fd, &addr, &len);
2604 	ASSERT_EQ(ret, 0);
2605 
2606 	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2607 	if (ret) {
2608 		ASSERT_EQ(errno, ENOENT);
2609 		SKIP(return, "no TLS support");
2610 	}
2611 	ASSERT_EQ(ret, 0);
2612 
2613 	ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
2614 	ASSERT_EQ(ret, 0);
2615 
2616 	ret = setsockopt(fd, SOL_TLS, TLS_RX, &tls12, tls12.len);
2617 	ASSERT_EQ(ret, 0);
2618 
2619 	len2 = sizeof(addr2);
2620 	ret = getsockname(fd, &addr2, &len2);
2621 	ASSERT_EQ(ret, 0);
2622 
2623 	EXPECT_EQ(len2, len);
2624 	EXPECT_EQ(memcmp(&addr, &addr2, len), 0);
2625 
2626 	close(fd);
2627 	close(sfd);
2628 }
2629 
2630 TEST(prequeue) {
2631 	struct tls_crypto_info_keys tls12;
2632 	char buf[20000], buf2[20000];
2633 	struct sockaddr_in addr;
2634 	int sfd, cfd, ret, fd;
2635 	socklen_t len;
2636 
2637 	len = sizeof(addr);
2638 	memrnd(buf, sizeof(buf));
2639 
2640 	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_256, &tls12, 0);
2641 
2642 	addr.sin_family = AF_INET;
2643 	addr.sin_addr.s_addr = htonl(INADDR_ANY);
2644 	addr.sin_port = 0;
2645 
2646 	fd = socket(AF_INET, SOCK_STREAM, 0);
2647 	sfd = socket(AF_INET, SOCK_STREAM, 0);
2648 
2649 	ASSERT_EQ(bind(sfd, &addr, sizeof(addr)), 0);
2650 	ASSERT_EQ(listen(sfd, 10), 0);
2651 	ASSERT_EQ(getsockname(sfd, &addr, &len), 0);
2652 	ASSERT_EQ(connect(fd, &addr, sizeof(addr)), 0);
2653 	ASSERT_GE(cfd = accept(sfd, &addr, &len), 0);
2654 	close(sfd);
2655 
2656 	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2657 	if (ret) {
2658 		ASSERT_EQ(errno, ENOENT);
2659 		SKIP(return, "no TLS support");
2660 	}
2661 
2662 	ASSERT_EQ(setsockopt(fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2663 	EXPECT_EQ(send(fd, buf, sizeof(buf), MSG_DONTWAIT), sizeof(buf));
2664 
2665 	ASSERT_EQ(setsockopt(cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls")), 0);
2666 	ASSERT_EQ(setsockopt(cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2667 	EXPECT_EQ(recv(cfd, buf2, sizeof(buf2), MSG_WAITALL), sizeof(buf2));
2668 
2669 	EXPECT_EQ(memcmp(buf, buf2, sizeof(buf)), 0);
2670 
2671 	close(fd);
2672 	close(cfd);
2673 }
2674 
2675 static void __attribute__((constructor)) fips_check(void) {
2676 	int res;
2677 	FILE *f;
2678 
2679 	f = fopen("/proc/sys/crypto/fips_enabled", "r");
2680 	if (f) {
2681 		res = fscanf(f, "%d", &fips_enabled);
2682 		if (res != 1)
2683 			ksft_print_msg("ERROR: Couldn't read /proc/sys/crypto/fips_enabled\n");
2684 		fclose(f);
2685 	}
2686 }
2687 
/* kselftest harness entry point: provides main() and runs the registered tests. */
TEST_HARNESS_MAIN
2689