xref: /linux/tools/testing/selftests/net/tls.c (revision 634ec1fc7982efeeeeed4a7688b0004827b43a21)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #define _GNU_SOURCE
4 
5 #include <arpa/inet.h>
6 #include <errno.h>
7 #include <error.h>
8 #include <fcntl.h>
9 #include <poll.h>
10 #include <stdio.h>
11 #include <stdlib.h>
12 #include <unistd.h>
13 
14 #include <linux/tls.h>
15 #include <linux/tcp.h>
16 #include <linux/socket.h>
17 
18 #include <sys/epoll.h>
19 #include <sys/types.h>
20 #include <sys/sendfile.h>
21 #include <sys/socket.h>
22 #include <sys/stat.h>
23 
24 #include "../kselftest_harness.h"
25 
/* Largest payload the tests put in a single send/record (16 KiB). */
#define TLS_PAYLOAD_MAX_LEN 16384

/*
 * Socket level for TLS socket options and control messages. Newer libc
 * headers may already provide it, so only define a fallback.
 */
#ifndef SOL_TLS
#define SOL_TLS 282
#endif

/*
 * Nonzero when running in FIPS mode; FIXTURE_SETUP(tls) then skips the
 * variants marked fips_non_compliant. Presumably set elsewhere in this
 * file before the tests run — confirm.
 */
static int fips_enabled;
30 
31 struct tls_crypto_info_keys {
32 	union {
33 		struct tls_crypto_info crypto_info;
34 		struct tls12_crypto_info_aes_gcm_128 aes128;
35 		struct tls12_crypto_info_chacha20_poly1305 chacha20;
36 		struct tls12_crypto_info_sm4_gcm sm4gcm;
37 		struct tls12_crypto_info_sm4_ccm sm4ccm;
38 		struct tls12_crypto_info_aes_ccm_128 aesccm128;
39 		struct tls12_crypto_info_aes_gcm_256 aesgcm256;
40 		struct tls12_crypto_info_aria_gcm_128 ariagcm128;
41 		struct tls12_crypto_info_aria_gcm_256 ariagcm256;
42 	};
43 	size_t len;
44 };
45 
tls_crypto_info_init(uint16_t tls_version,uint16_t cipher_type,struct tls_crypto_info_keys * tls12,char key_generation)46 static void tls_crypto_info_init(uint16_t tls_version, uint16_t cipher_type,
47 				 struct tls_crypto_info_keys *tls12,
48 				 char key_generation)
49 {
50 	memset(tls12, key_generation, sizeof(*tls12));
51 	memset(tls12, 0, sizeof(struct tls_crypto_info));
52 
53 	switch (cipher_type) {
54 	case TLS_CIPHER_CHACHA20_POLY1305:
55 		tls12->len = sizeof(struct tls12_crypto_info_chacha20_poly1305);
56 		tls12->chacha20.info.version = tls_version;
57 		tls12->chacha20.info.cipher_type = cipher_type;
58 		break;
59 	case TLS_CIPHER_AES_GCM_128:
60 		tls12->len = sizeof(struct tls12_crypto_info_aes_gcm_128);
61 		tls12->aes128.info.version = tls_version;
62 		tls12->aes128.info.cipher_type = cipher_type;
63 		break;
64 	case TLS_CIPHER_SM4_GCM:
65 		tls12->len = sizeof(struct tls12_crypto_info_sm4_gcm);
66 		tls12->sm4gcm.info.version = tls_version;
67 		tls12->sm4gcm.info.cipher_type = cipher_type;
68 		break;
69 	case TLS_CIPHER_SM4_CCM:
70 		tls12->len = sizeof(struct tls12_crypto_info_sm4_ccm);
71 		tls12->sm4ccm.info.version = tls_version;
72 		tls12->sm4ccm.info.cipher_type = cipher_type;
73 		break;
74 	case TLS_CIPHER_AES_CCM_128:
75 		tls12->len = sizeof(struct tls12_crypto_info_aes_ccm_128);
76 		tls12->aesccm128.info.version = tls_version;
77 		tls12->aesccm128.info.cipher_type = cipher_type;
78 		break;
79 	case TLS_CIPHER_AES_GCM_256:
80 		tls12->len = sizeof(struct tls12_crypto_info_aes_gcm_256);
81 		tls12->aesgcm256.info.version = tls_version;
82 		tls12->aesgcm256.info.cipher_type = cipher_type;
83 		break;
84 	case TLS_CIPHER_ARIA_GCM_128:
85 		tls12->len = sizeof(struct tls12_crypto_info_aria_gcm_128);
86 		tls12->ariagcm128.info.version = tls_version;
87 		tls12->ariagcm128.info.cipher_type = cipher_type;
88 		break;
89 	case TLS_CIPHER_ARIA_GCM_256:
90 		tls12->len = sizeof(struct tls12_crypto_info_aria_gcm_256);
91 		tls12->ariagcm256.info.version = tls_version;
92 		tls12->ariagcm256.info.cipher_type = cipher_type;
93 		break;
94 	default:
95 		break;
96 	}
97 }
98 
/*
 * Fill @n bytes at @s with pseudo-random data from rand().
 *
 * Whole int-sized chunks are copied with memcpy(), so @s needs no int
 * alignment. The buffer is also never accessed through an int lvalue,
 * which would break strict aliasing on the char arrays callers pass.
 * Each remaining tail byte takes the low byte of a fresh rand() value.
 */
static void memrnd(void *s, size_t n)
{
	unsigned char *p = s;
	int r;

	for (; n >= sizeof(r); n -= sizeof(r), p += sizeof(r)) {
		r = rand();
		memcpy(p, &r, sizeof(r));
	}
	while (n--)
		*p++ = (unsigned char)rand();
}
110 
ulp_sock_pair(struct __test_metadata * _metadata,int * fd,int * cfd,bool * notls)111 static void ulp_sock_pair(struct __test_metadata *_metadata,
112 			  int *fd, int *cfd, bool *notls)
113 {
114 	struct sockaddr_in addr;
115 	socklen_t len;
116 	int sfd, ret;
117 
118 	*notls = false;
119 	len = sizeof(addr);
120 
121 	addr.sin_family = AF_INET;
122 	addr.sin_addr.s_addr = htonl(INADDR_ANY);
123 	addr.sin_port = 0;
124 
125 	*fd = socket(AF_INET, SOCK_STREAM, 0);
126 	sfd = socket(AF_INET, SOCK_STREAM, 0);
127 
128 	ret = bind(sfd, &addr, sizeof(addr));
129 	ASSERT_EQ(ret, 0);
130 	ret = listen(sfd, 10);
131 	ASSERT_EQ(ret, 0);
132 
133 	ret = getsockname(sfd, &addr, &len);
134 	ASSERT_EQ(ret, 0);
135 
136 	ret = connect(*fd, &addr, sizeof(addr));
137 	ASSERT_EQ(ret, 0);
138 
139 	*cfd = accept(sfd, &addr, &len);
140 	ASSERT_GE(*cfd, 0);
141 
142 	close(sfd);
143 
144 	ret = setsockopt(*fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
145 	if (ret != 0) {
146 		ASSERT_EQ(errno, ENOENT);
147 		*notls = true;
148 		printf("Failure setting TCP_ULP, testing without tls\n");
149 		return;
150 	}
151 
152 	ret = setsockopt(*cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
153 	ASSERT_EQ(ret, 0);
154 }
155 
156 /* Produce a basic cmsg */
tls_send_cmsg(int fd,unsigned char record_type,void * data,size_t len,int flags)157 static int tls_send_cmsg(int fd, unsigned char record_type,
158 			 void *data, size_t len, int flags)
159 {
160 	char cbuf[CMSG_SPACE(sizeof(char))];
161 	int cmsg_len = sizeof(char);
162 	struct cmsghdr *cmsg;
163 	struct msghdr msg;
164 	struct iovec vec;
165 
166 	vec.iov_base = data;
167 	vec.iov_len = len;
168 	memset(&msg, 0, sizeof(struct msghdr));
169 	msg.msg_iov = &vec;
170 	msg.msg_iovlen = 1;
171 	msg.msg_control = cbuf;
172 	msg.msg_controllen = sizeof(cbuf);
173 	cmsg = CMSG_FIRSTHDR(&msg);
174 	cmsg->cmsg_level = SOL_TLS;
175 	/* test sending non-record types. */
176 	cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
177 	cmsg->cmsg_len = CMSG_LEN(cmsg_len);
178 	*CMSG_DATA(cmsg) = record_type;
179 	msg.msg_controllen = cmsg->cmsg_len;
180 
181 	return sendmsg(fd, &msg, flags);
182 }
183 
__tls_recv_cmsg(struct __test_metadata * _metadata,int fd,unsigned char * ctype,void * data,size_t len,int flags)184 static int __tls_recv_cmsg(struct __test_metadata *_metadata,
185 			   int fd, unsigned char *ctype,
186 			   void *data, size_t len, int flags)
187 {
188 	char cbuf[CMSG_SPACE(sizeof(char))];
189 	struct cmsghdr *cmsg;
190 	struct msghdr msg;
191 	struct iovec vec;
192 	int n;
193 
194 	vec.iov_base = data;
195 	vec.iov_len = len;
196 	memset(&msg, 0, sizeof(struct msghdr));
197 	msg.msg_iov = &vec;
198 	msg.msg_iovlen = 1;
199 	msg.msg_control = cbuf;
200 	msg.msg_controllen = sizeof(cbuf);
201 
202 	n = recvmsg(fd, &msg, flags);
203 
204 	cmsg = CMSG_FIRSTHDR(&msg);
205 	EXPECT_NE(cmsg, NULL);
206 	EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
207 	EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
208 	if (ctype)
209 		*ctype = *((unsigned char *)CMSG_DATA(cmsg));
210 
211 	return n;
212 }
213 
/*
 * Receive from @fd via __tls_recv_cmsg() and check that the record type
 * reported in the control message equals @record_type.
 * Returns the recvmsg() result.
 */
static int tls_recv_cmsg(struct __test_metadata *_metadata,
			 int fd, unsigned char record_type,
			 void *data, size_t len, int flags)
{
	unsigned char got_type = 0;
	int ret = __tls_recv_cmsg(_metadata, fd, &got_type, data, len, flags);

	EXPECT_EQ(got_type, record_type);
	return ret;
}
226 
/* State for tls_basic tests: a socket pair from ulp_sock_pair(). */
FIXTURE(tls_basic)
{
	int fd;		/* connect()ed end */
	int cfd;	/* accept()ed end */
	bool notls;	/* TCP_ULP "tls" was unavailable (ENOENT) */
};
232 
FIXTURE_SETUP(tls_basic)233 FIXTURE_SETUP(tls_basic)
234 {
235 	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
236 }
237 
FIXTURE_TEARDOWN(tls_basic)238 FIXTURE_TEARDOWN(tls_basic)
239 {
240 	close(self->fd);
241 	close(self->cfd);
242 }
243 
244 /* Send some data through with ULP but no keys */
TEST_F(tls_basic,base_base)245 TEST_F(tls_basic, base_base)
246 {
247 	char const *test_str = "test_read";
248 	int send_len = 10;
249 	char buf[10];
250 
251 	ASSERT_EQ(strlen(test_str) + 1, send_len);
252 
253 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
254 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
255 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
256 };
257 
TEST_F(tls_basic,bad_cipher)258 TEST_F(tls_basic, bad_cipher)
259 {
260 	struct tls_crypto_info_keys tls12;
261 
262 	tls12.crypto_info.version = 200;
263 	tls12.crypto_info.cipher_type = TLS_CIPHER_AES_GCM_128;
264 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
265 
266 	tls12.crypto_info.version = TLS_1_2_VERSION;
267 	tls12.crypto_info.cipher_type = 50;
268 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
269 
270 	tls12.crypto_info.version = TLS_1_2_VERSION;
271 	tls12.crypto_info.cipher_type = 59;
272 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
273 
274 	tls12.crypto_info.version = TLS_1_2_VERSION;
275 	tls12.crypto_info.cipher_type = 10;
276 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
277 
278 	tls12.crypto_info.version = TLS_1_2_VERSION;
279 	tls12.crypto_info.cipher_type = 70;
280 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
281 }
282 
TEST_F(tls_basic,recseq_wrap)283 TEST_F(tls_basic, recseq_wrap)
284 {
285 	struct tls_crypto_info_keys tls12;
286 	char const *test_str = "test_read";
287 	int send_len = 10;
288 
289 	if (self->notls)
290 		SKIP(return, "no TLS support");
291 
292 	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_128, &tls12, 0);
293 	memset(&tls12.aes128.rec_seq, 0xff, sizeof(tls12.aes128.rec_seq));
294 
295 	ASSERT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
296 	ASSERT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
297 
298 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), -1);
299 	EXPECT_EQ(errno, EBADMSG);
300 }
301 
/* State for tls tests: a socket pair with the variant's keys installed. */
FIXTURE(tls)
{
	int fd;		/* connect()ed end; TLS_TX is configured here */
	int cfd;	/* accept()ed end; TLS_RX is configured here */
	bool notls;	/* TCP_ULP "tls" was unavailable (ENOENT) */
};
307 
/* Per-variant parameters consumed by FIXTURE_SETUP(tls). */
FIXTURE_VARIANT(tls)
{
	uint16_t tls_version;		/* TLS_1_x_VERSION */
	uint16_t cipher_type;		/* TLS_CIPHER_* */
	bool nopad;			/* also set TLS_RX_EXPECT_NO_PAD on cfd */
	bool fips_non_compliant;	/* skipped when fips_enabled */
};
314 
315 FIXTURE_VARIANT_ADD(tls, 12_aes_gcm)
316 {
317 	.tls_version = TLS_1_2_VERSION,
318 	.cipher_type = TLS_CIPHER_AES_GCM_128,
319 };
320 
321 FIXTURE_VARIANT_ADD(tls, 13_aes_gcm)
322 {
323 	.tls_version = TLS_1_3_VERSION,
324 	.cipher_type = TLS_CIPHER_AES_GCM_128,
325 };
326 
327 FIXTURE_VARIANT_ADD(tls, 12_chacha)
328 {
329 	.tls_version = TLS_1_2_VERSION,
330 	.cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
331 	.fips_non_compliant = true,
332 };
333 
334 FIXTURE_VARIANT_ADD(tls, 13_chacha)
335 {
336 	.tls_version = TLS_1_3_VERSION,
337 	.cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
338 	.fips_non_compliant = true,
339 };
340 
341 FIXTURE_VARIANT_ADD(tls, 13_sm4_gcm)
342 {
343 	.tls_version = TLS_1_3_VERSION,
344 	.cipher_type = TLS_CIPHER_SM4_GCM,
345 	.fips_non_compliant = true,
346 };
347 
348 FIXTURE_VARIANT_ADD(tls, 13_sm4_ccm)
349 {
350 	.tls_version = TLS_1_3_VERSION,
351 	.cipher_type = TLS_CIPHER_SM4_CCM,
352 	.fips_non_compliant = true,
353 };
354 
355 FIXTURE_VARIANT_ADD(tls, 12_aes_ccm)
356 {
357 	.tls_version = TLS_1_2_VERSION,
358 	.cipher_type = TLS_CIPHER_AES_CCM_128,
359 };
360 
361 FIXTURE_VARIANT_ADD(tls, 13_aes_ccm)
362 {
363 	.tls_version = TLS_1_3_VERSION,
364 	.cipher_type = TLS_CIPHER_AES_CCM_128,
365 };
366 
367 FIXTURE_VARIANT_ADD(tls, 12_aes_gcm_256)
368 {
369 	.tls_version = TLS_1_2_VERSION,
370 	.cipher_type = TLS_CIPHER_AES_GCM_256,
371 };
372 
373 FIXTURE_VARIANT_ADD(tls, 13_aes_gcm_256)
374 {
375 	.tls_version = TLS_1_3_VERSION,
376 	.cipher_type = TLS_CIPHER_AES_GCM_256,
377 };
378 
379 FIXTURE_VARIANT_ADD(tls, 13_nopad)
380 {
381 	.tls_version = TLS_1_3_VERSION,
382 	.cipher_type = TLS_CIPHER_AES_GCM_128,
383 	.nopad = true,
384 };
385 
386 FIXTURE_VARIANT_ADD(tls, 12_aria_gcm)
387 {
388 	.tls_version = TLS_1_2_VERSION,
389 	.cipher_type = TLS_CIPHER_ARIA_GCM_128,
390 };
391 
392 FIXTURE_VARIANT_ADD(tls, 12_aria_gcm_256)
393 {
394 	.tls_version = TLS_1_2_VERSION,
395 	.cipher_type = TLS_CIPHER_ARIA_GCM_256,
396 };
397 
FIXTURE_SETUP(tls)398 FIXTURE_SETUP(tls)
399 {
400 	struct tls_crypto_info_keys tls12;
401 	int one = 1;
402 	int ret;
403 
404 	if (fips_enabled && variant->fips_non_compliant)
405 		SKIP(return, "Unsupported cipher in FIPS mode");
406 
407 	tls_crypto_info_init(variant->tls_version, variant->cipher_type,
408 			     &tls12, 0);
409 
410 	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
411 
412 	if (self->notls)
413 		return;
414 
415 	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
416 	ASSERT_EQ(ret, 0);
417 
418 	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len);
419 	ASSERT_EQ(ret, 0);
420 
421 	if (variant->nopad) {
422 		ret = setsockopt(self->cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
423 				 (void *)&one, sizeof(one));
424 		ASSERT_EQ(ret, 0);
425 	}
426 }
427 
FIXTURE_TEARDOWN(tls)428 FIXTURE_TEARDOWN(tls)
429 {
430 	close(self->fd);
431 	close(self->cfd);
432 }
433 
TEST_F(tls,sendfile)434 TEST_F(tls, sendfile)
435 {
436 	int filefd = open("/proc/self/exe", O_RDONLY);
437 	struct stat st;
438 
439 	EXPECT_GE(filefd, 0);
440 	fstat(filefd, &st);
441 	EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
442 
443 	close(filefd);
444 }
445 
TEST_F(tls,send_then_sendfile)446 TEST_F(tls, send_then_sendfile)
447 {
448 	int filefd = open("/proc/self/exe", O_RDONLY);
449 	char const *test_str = "test_send";
450 	int to_send = strlen(test_str) + 1;
451 	char recv_buf[10];
452 	struct stat st;
453 	char *buf;
454 
455 	EXPECT_GE(filefd, 0);
456 	fstat(filefd, &st);
457 	buf = (char *)malloc(st.st_size);
458 
459 	EXPECT_EQ(send(self->fd, test_str, to_send, 0), to_send);
460 	EXPECT_EQ(recv(self->cfd, recv_buf, to_send, MSG_WAITALL), to_send);
461 	EXPECT_EQ(memcmp(test_str, recv_buf, to_send), 0);
462 
463 	EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
464 	EXPECT_EQ(recv(self->cfd, buf, st.st_size, MSG_WAITALL), st.st_size);
465 
466 	free(buf);
467 	close(filefd);
468 }
469 
chunked_sendfile(struct __test_metadata * _metadata,struct _test_data_tls * self,uint16_t chunk_size,uint16_t extra_payload_size)470 static void chunked_sendfile(struct __test_metadata *_metadata,
471 			     struct _test_data_tls *self,
472 			     uint16_t chunk_size,
473 			     uint16_t extra_payload_size)
474 {
475 	char buf[TLS_PAYLOAD_MAX_LEN];
476 	uint16_t test_payload_size;
477 	int size = 0;
478 	int ret;
479 	char filename[] = "/tmp/mytemp.XXXXXX";
480 	int fd = mkstemp(filename);
481 	off_t offset = 0;
482 
483 	unlink(filename);
484 	ASSERT_GE(fd, 0);
485 	EXPECT_GE(chunk_size, 1);
486 	test_payload_size = chunk_size + extra_payload_size;
487 	ASSERT_GE(TLS_PAYLOAD_MAX_LEN, test_payload_size);
488 	memset(buf, 1, test_payload_size);
489 	size = write(fd, buf, test_payload_size);
490 	EXPECT_EQ(size, test_payload_size);
491 	fsync(fd);
492 
493 	while (size > 0) {
494 		ret = sendfile(self->fd, fd, &offset, chunk_size);
495 		EXPECT_GE(ret, 0);
496 		size -= ret;
497 	}
498 
499 	EXPECT_EQ(recv(self->cfd, buf, test_payload_size, MSG_WAITALL),
500 		  test_payload_size);
501 
502 	close(fd);
503 }
504 
TEST_F(tls,multi_chunk_sendfile)505 TEST_F(tls, multi_chunk_sendfile)
506 {
507 	chunked_sendfile(_metadata, self, 4096, 4096);
508 	chunked_sendfile(_metadata, self, 4096, 0);
509 	chunked_sendfile(_metadata, self, 4096, 1);
510 	chunked_sendfile(_metadata, self, 4096, 2048);
511 	chunked_sendfile(_metadata, self, 8192, 2048);
512 	chunked_sendfile(_metadata, self, 4096, 8192);
513 	chunked_sendfile(_metadata, self, 8192, 4096);
514 	chunked_sendfile(_metadata, self, 12288, 1024);
515 	chunked_sendfile(_metadata, self, 12288, 2000);
516 	chunked_sendfile(_metadata, self, 15360, 100);
517 	chunked_sendfile(_metadata, self, 15360, 300);
518 	chunked_sendfile(_metadata, self, 1, 4096);
519 	chunked_sendfile(_metadata, self, 2048, 4096);
520 	chunked_sendfile(_metadata, self, 2048, 8192);
521 	chunked_sendfile(_metadata, self, 4096, 8192);
522 	chunked_sendfile(_metadata, self, 1024, 12288);
523 	chunked_sendfile(_metadata, self, 2000, 12288);
524 	chunked_sendfile(_metadata, self, 100, 15360);
525 	chunked_sendfile(_metadata, self, 300, 15360);
526 }
527 
TEST_F(tls,recv_max)528 TEST_F(tls, recv_max)
529 {
530 	unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
531 	char recv_mem[TLS_PAYLOAD_MAX_LEN];
532 	char buf[TLS_PAYLOAD_MAX_LEN];
533 
534 	memrnd(buf, sizeof(buf));
535 
536 	EXPECT_GE(send(self->fd, buf, send_len, 0), 0);
537 	EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
538 	EXPECT_EQ(memcmp(buf, recv_mem, send_len), 0);
539 }
540 
TEST_F(tls,recv_small)541 TEST_F(tls, recv_small)
542 {
543 	char const *test_str = "test_read";
544 	int send_len = 10;
545 	char buf[10];
546 
547 	send_len = strlen(test_str) + 1;
548 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
549 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
550 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
551 }
552 
TEST_F(tls,msg_more)553 TEST_F(tls, msg_more)
554 {
555 	char const *test_str = "test_read";
556 	int send_len = 10;
557 	char buf[10 * 2];
558 
559 	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
560 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
561 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
562 	EXPECT_EQ(recv(self->cfd, buf, send_len * 2, MSG_WAITALL),
563 		  send_len * 2);
564 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
565 }
566 
TEST_F(tls,cmsg_msg_more)567 TEST_F(tls, cmsg_msg_more)
568 {
569 	char *test_str =  "test_read";
570 	char record_type = 100;
571 	int send_len = 10;
572 
573 	/* we don't allow MSG_MORE with non-DATA records */
574 	EXPECT_EQ(tls_send_cmsg(self->fd, record_type, test_str, send_len,
575 				MSG_MORE), -1);
576 	EXPECT_EQ(errno, EINVAL);
577 }
578 
TEST_F(tls,msg_more_then_cmsg)579 TEST_F(tls, msg_more_then_cmsg)
580 {
581 	char *test_str = "test_read";
582 	char record_type = 100;
583 	int send_len = 10;
584 	char buf[10 * 2];
585 	int ret;
586 
587 	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
588 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
589 
590 	ret = tls_send_cmsg(self->fd, record_type, test_str, send_len, 0);
591 	EXPECT_EQ(ret, send_len);
592 
593 	/* initial DATA record didn't get merged with the non-DATA record */
594 	EXPECT_EQ(recv(self->cfd, buf, send_len * 2, 0), send_len);
595 
596 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
597 				buf, sizeof(buf), MSG_WAITALL),
598 		  send_len);
599 }
600 
TEST_F(tls,msg_more_unsent)601 TEST_F(tls, msg_more_unsent)
602 {
603 	char const *test_str = "test_read";
604 	int send_len = 10;
605 	char buf[10];
606 
607 	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
608 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
609 }
610 
TEST_F(tls,msg_eor)611 TEST_F(tls, msg_eor)
612 {
613 	char const *test_str = "test_read";
614 	int send_len = 10;
615 	char buf[10];
616 
617 	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_EOR), send_len);
618 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
619 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
620 }
621 
TEST_F(tls,sendmsg_single)622 TEST_F(tls, sendmsg_single)
623 {
624 	struct msghdr msg;
625 
626 	char const *test_str = "test_sendmsg";
627 	size_t send_len = 13;
628 	struct iovec vec;
629 	char buf[13];
630 
631 	vec.iov_base = (char *)test_str;
632 	vec.iov_len = send_len;
633 	memset(&msg, 0, sizeof(struct msghdr));
634 	msg.msg_iov = &vec;
635 	msg.msg_iovlen = 1;
636 	EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
637 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
638 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
639 }
640 
641 #define MAX_FRAGS	64
642 #define SEND_LEN	13
TEST_F(tls,sendmsg_fragmented)643 TEST_F(tls, sendmsg_fragmented)
644 {
645 	char const *test_str = "test_sendmsg";
646 	char buf[SEND_LEN * MAX_FRAGS];
647 	struct iovec vec[MAX_FRAGS];
648 	struct msghdr msg;
649 	int i, frags;
650 
651 	for (frags = 1; frags <= MAX_FRAGS; frags++) {
652 		for (i = 0; i < frags; i++) {
653 			vec[i].iov_base = (char *)test_str;
654 			vec[i].iov_len = SEND_LEN;
655 		}
656 
657 		memset(&msg, 0, sizeof(struct msghdr));
658 		msg.msg_iov = vec;
659 		msg.msg_iovlen = frags;
660 
661 		EXPECT_EQ(sendmsg(self->fd, &msg, 0), SEND_LEN * frags);
662 		EXPECT_EQ(recv(self->cfd, buf, SEND_LEN * frags, MSG_WAITALL),
663 			  SEND_LEN * frags);
664 
665 		for (i = 0; i < frags; i++)
666 			EXPECT_EQ(memcmp(buf + SEND_LEN * i,
667 					 test_str, SEND_LEN), 0);
668 	}
669 }
670 #undef MAX_FRAGS
671 #undef SEND_LEN
672 
TEST_F(tls,sendmsg_large)673 TEST_F(tls, sendmsg_large)
674 {
675 	void *mem = malloc(16384);
676 	size_t send_len = 16384;
677 	size_t sends = 128;
678 	struct msghdr msg;
679 	size_t recvs = 0;
680 	size_t sent = 0;
681 
682 	memset(&msg, 0, sizeof(struct msghdr));
683 	while (sent++ < sends) {
684 		struct iovec vec = { (void *)mem, send_len };
685 
686 		msg.msg_iov = &vec;
687 		msg.msg_iovlen = 1;
688 		EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
689 	}
690 
691 	while (recvs++ < sends) {
692 		EXPECT_NE(recv(self->cfd, mem, send_len, 0), -1);
693 	}
694 
695 	free(mem);
696 }
697 
TEST_F(tls,sendmsg_multiple)698 TEST_F(tls, sendmsg_multiple)
699 {
700 	char const *test_str = "test_sendmsg_multiple";
701 	struct iovec vec[5];
702 	char *test_strs[5];
703 	struct msghdr msg;
704 	int total_len = 0;
705 	int len_cmp = 0;
706 	int iov_len = 5;
707 	char *buf;
708 	int i;
709 
710 	memset(&msg, 0, sizeof(struct msghdr));
711 	for (i = 0; i < iov_len; i++) {
712 		test_strs[i] = (char *)malloc(strlen(test_str) + 1);
713 		snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
714 		vec[i].iov_base = (void *)test_strs[i];
715 		vec[i].iov_len = strlen(test_strs[i]) + 1;
716 		total_len += vec[i].iov_len;
717 	}
718 	msg.msg_iov = vec;
719 	msg.msg_iovlen = iov_len;
720 
721 	EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
722 	buf = malloc(total_len);
723 	EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
724 	for (i = 0; i < iov_len; i++) {
725 		EXPECT_EQ(memcmp(test_strs[i], buf + len_cmp,
726 				 strlen(test_strs[i])),
727 			  0);
728 		len_cmp += strlen(buf + len_cmp) + 1;
729 	}
730 	for (i = 0; i < iov_len; i++)
731 		free(test_strs[i]);
732 	free(buf);
733 }
734 
TEST_F(tls,sendmsg_multiple_stress)735 TEST_F(tls, sendmsg_multiple_stress)
736 {
737 	char const *test_str = "abcdefghijklmno";
738 	struct iovec vec[1024];
739 	char *test_strs[1024];
740 	int iov_len = 1024;
741 	int total_len = 0;
742 	char buf[1 << 14];
743 	struct msghdr msg;
744 	int len_cmp = 0;
745 	int i;
746 
747 	memset(&msg, 0, sizeof(struct msghdr));
748 	for (i = 0; i < iov_len; i++) {
749 		test_strs[i] = (char *)malloc(strlen(test_str) + 1);
750 		snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
751 		vec[i].iov_base = (void *)test_strs[i];
752 		vec[i].iov_len = strlen(test_strs[i]) + 1;
753 		total_len += vec[i].iov_len;
754 	}
755 	msg.msg_iov = vec;
756 	msg.msg_iovlen = iov_len;
757 
758 	EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
759 	EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
760 
761 	for (i = 0; i < iov_len; i++)
762 		len_cmp += strlen(buf + len_cmp) + 1;
763 
764 	for (i = 0; i < iov_len; i++)
765 		free(test_strs[i]);
766 }
767 
TEST_F(tls,splice_from_pipe)768 TEST_F(tls, splice_from_pipe)
769 {
770 	int send_len = TLS_PAYLOAD_MAX_LEN;
771 	char mem_send[TLS_PAYLOAD_MAX_LEN];
772 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
773 	int p[2];
774 
775 	ASSERT_GE(pipe(p), 0);
776 	EXPECT_GE(write(p[1], mem_send, send_len), 0);
777 	EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), 0);
778 	EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
779 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
780 }
781 
TEST_F(tls,splice_more)782 TEST_F(tls, splice_more)
783 {
784 	unsigned int f = SPLICE_F_NONBLOCK | SPLICE_F_MORE | SPLICE_F_GIFT;
785 	int send_len = TLS_PAYLOAD_MAX_LEN;
786 	char mem_send[TLS_PAYLOAD_MAX_LEN];
787 	int i, send_pipe = 1;
788 	int p[2];
789 
790 	ASSERT_GE(pipe(p), 0);
791 	EXPECT_GE(write(p[1], mem_send, send_len), 0);
792 	for (i = 0; i < 32; i++)
793 		EXPECT_EQ(splice(p[0], NULL, self->fd, NULL, send_pipe, f), 1);
794 }
795 
TEST_F(tls,splice_from_pipe2)796 TEST_F(tls, splice_from_pipe2)
797 {
798 	int send_len = 16000;
799 	char mem_send[16000];
800 	char mem_recv[16000];
801 	int p2[2];
802 	int p[2];
803 
804 	memrnd(mem_send, sizeof(mem_send));
805 
806 	ASSERT_GE(pipe(p), 0);
807 	ASSERT_GE(pipe(p2), 0);
808 	EXPECT_EQ(write(p[1], mem_send, 8000), 8000);
809 	EXPECT_EQ(splice(p[0], NULL, self->fd, NULL, 8000, 0), 8000);
810 	EXPECT_EQ(write(p2[1], mem_send + 8000, 8000), 8000);
811 	EXPECT_EQ(splice(p2[0], NULL, self->fd, NULL, 8000, 0), 8000);
812 	EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
813 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
814 }
815 
TEST_F(tls,send_and_splice)816 TEST_F(tls, send_and_splice)
817 {
818 	int send_len = TLS_PAYLOAD_MAX_LEN;
819 	char mem_send[TLS_PAYLOAD_MAX_LEN];
820 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
821 	char const *test_str = "test_read";
822 	int send_len2 = 10;
823 	char buf[10];
824 	int p[2];
825 
826 	ASSERT_GE(pipe(p), 0);
827 	EXPECT_EQ(send(self->fd, test_str, send_len2, 0), send_len2);
828 	EXPECT_EQ(recv(self->cfd, buf, send_len2, MSG_WAITALL), send_len2);
829 	EXPECT_EQ(memcmp(test_str, buf, send_len2), 0);
830 
831 	EXPECT_GE(write(p[1], mem_send, send_len), send_len);
832 	EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), send_len);
833 
834 	EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
835 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
836 }
837 
TEST_F(tls,splice_to_pipe)838 TEST_F(tls, splice_to_pipe)
839 {
840 	int send_len = TLS_PAYLOAD_MAX_LEN;
841 	char mem_send[TLS_PAYLOAD_MAX_LEN];
842 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
843 	int p[2];
844 
845 	memrnd(mem_send, sizeof(mem_send));
846 
847 	ASSERT_GE(pipe(p), 0);
848 	EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
849 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), send_len);
850 	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
851 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
852 }
853 
TEST_F(tls,splice_cmsg_to_pipe)854 TEST_F(tls, splice_cmsg_to_pipe)
855 {
856 	char *test_str = "test_read";
857 	char record_type = 100;
858 	int send_len = 10;
859 	char buf[10];
860 	int p[2];
861 
862 	if (self->notls)
863 		SKIP(return, "no TLS support");
864 
865 	ASSERT_GE(pipe(p), 0);
866 	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
867 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
868 	EXPECT_EQ(errno, EINVAL);
869 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
870 	EXPECT_EQ(errno, EIO);
871 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
872 				buf, sizeof(buf), MSG_WAITALL),
873 		  send_len);
874 	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
875 }
876 
TEST_F(tls,splice_dec_cmsg_to_pipe)877 TEST_F(tls, splice_dec_cmsg_to_pipe)
878 {
879 	char *test_str = "test_read";
880 	char record_type = 100;
881 	int send_len = 10;
882 	char buf[10];
883 	int p[2];
884 
885 	if (self->notls)
886 		SKIP(return, "no TLS support");
887 
888 	ASSERT_GE(pipe(p), 0);
889 	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
890 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
891 	EXPECT_EQ(errno, EIO);
892 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
893 	EXPECT_EQ(errno, EINVAL);
894 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
895 				buf, sizeof(buf), MSG_WAITALL),
896 		  send_len);
897 	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
898 }
899 
TEST_F(tls,recv_and_splice)900 TEST_F(tls, recv_and_splice)
901 {
902 	int send_len = TLS_PAYLOAD_MAX_LEN;
903 	char mem_send[TLS_PAYLOAD_MAX_LEN];
904 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
905 	int half = send_len / 2;
906 	int p[2];
907 
908 	ASSERT_GE(pipe(p), 0);
909 	EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
910 	/* Recv hald of the record, splice the other half */
911 	EXPECT_EQ(recv(self->cfd, mem_recv, half, MSG_WAITALL), half);
912 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, half, SPLICE_F_NONBLOCK),
913 		  half);
914 	EXPECT_EQ(read(p[0], &mem_recv[half], half), half);
915 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
916 }
917 
TEST_F(tls,peek_and_splice)918 TEST_F(tls, peek_and_splice)
919 {
920 	int send_len = TLS_PAYLOAD_MAX_LEN;
921 	char mem_send[TLS_PAYLOAD_MAX_LEN];
922 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
923 	int chunk = TLS_PAYLOAD_MAX_LEN / 4;
924 	int n, i, p[2];
925 
926 	memrnd(mem_send, sizeof(mem_send));
927 
928 	ASSERT_GE(pipe(p), 0);
929 	for (i = 0; i < 4; i++)
930 		EXPECT_EQ(send(self->fd, &mem_send[chunk * i], chunk, 0),
931 			  chunk);
932 
933 	EXPECT_EQ(recv(self->cfd, mem_recv, chunk * 5 / 2,
934 		       MSG_WAITALL | MSG_PEEK),
935 		  chunk * 5 / 2);
936 	EXPECT_EQ(memcmp(mem_send, mem_recv, chunk * 5 / 2), 0);
937 
938 	n = 0;
939 	while (n < send_len) {
940 		i = splice(self->cfd, NULL, p[1], NULL, send_len - n, 0);
941 		EXPECT_GT(i, 0);
942 		n += i;
943 	}
944 	EXPECT_EQ(n, send_len);
945 	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
946 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
947 }
948 
949 #define MAX_FRAGS 48
TEST_F(tls,splice_short)950 TEST_F(tls, splice_short)
951 {
952 	struct iovec sendchar_iov;
953 	char read_buf[0x10000];
954 	char sendbuf[0x100];
955 	char sendchar = 'S';
956 	int pipefds[2];
957 	int i;
958 
959 	sendchar_iov.iov_base = &sendchar;
960 	sendchar_iov.iov_len = 1;
961 
962 	memset(sendbuf, 's', sizeof(sendbuf));
963 
964 	ASSERT_GE(pipe2(pipefds, O_NONBLOCK), 0);
965 	ASSERT_GE(fcntl(pipefds[0], F_SETPIPE_SZ, (MAX_FRAGS + 1) * 0x1000), 0);
966 
967 	for (i = 0; i < MAX_FRAGS; i++)
968 		ASSERT_GE(vmsplice(pipefds[1], &sendchar_iov, 1, 0), 0);
969 
970 	ASSERT_EQ(write(pipefds[1], sendbuf, sizeof(sendbuf)), sizeof(sendbuf));
971 
972 	EXPECT_EQ(splice(pipefds[0], NULL, self->fd, NULL, MAX_FRAGS + 0x1000, 0),
973 		  MAX_FRAGS + sizeof(sendbuf));
974 	EXPECT_EQ(recv(self->cfd, read_buf, sizeof(read_buf), 0), MAX_FRAGS + sizeof(sendbuf));
975 	EXPECT_EQ(recv(self->cfd, read_buf, sizeof(read_buf), MSG_DONTWAIT), -1);
976 	EXPECT_EQ(errno, EAGAIN);
977 }
978 #undef MAX_FRAGS
979 
TEST_F(tls,recvmsg_single)980 TEST_F(tls, recvmsg_single)
981 {
982 	char const *test_str = "test_recvmsg_single";
983 	int send_len = strlen(test_str) + 1;
984 	char buf[20];
985 	struct msghdr hdr;
986 	struct iovec vec;
987 
988 	memset(&hdr, 0, sizeof(hdr));
989 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
990 	vec.iov_base = (char *)buf;
991 	vec.iov_len = send_len;
992 	hdr.msg_iovlen = 1;
993 	hdr.msg_iov = &vec;
994 	EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
995 	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
996 }
997 
TEST_F(tls,recvmsg_single_max)998 TEST_F(tls, recvmsg_single_max)
999 {
1000 	int send_len = TLS_PAYLOAD_MAX_LEN;
1001 	char send_mem[TLS_PAYLOAD_MAX_LEN];
1002 	char recv_mem[TLS_PAYLOAD_MAX_LEN];
1003 	struct iovec vec;
1004 	struct msghdr hdr;
1005 
1006 	memrnd(send_mem, sizeof(send_mem));
1007 
1008 	EXPECT_EQ(send(self->fd, send_mem, send_len, 0), send_len);
1009 	vec.iov_base = (char *)recv_mem;
1010 	vec.iov_len = TLS_PAYLOAD_MAX_LEN;
1011 
1012 	hdr.msg_iovlen = 1;
1013 	hdr.msg_iov = &vec;
1014 	EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
1015 	EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
1016 }
1017 
TEST_F(tls,recvmsg_multiple)1018 TEST_F(tls, recvmsg_multiple)
1019 {
1020 	unsigned int msg_iovlen = 1024;
1021 	struct iovec vec[1024];
1022 	char *iov_base[1024];
1023 	unsigned int iov_len = 16;
1024 	int send_len = 1 << 14;
1025 	char buf[1 << 14];
1026 	struct msghdr hdr;
1027 	int i;
1028 
1029 	memrnd(buf, sizeof(buf));
1030 
1031 	EXPECT_EQ(send(self->fd, buf, send_len, 0), send_len);
1032 	for (i = 0; i < msg_iovlen; i++) {
1033 		iov_base[i] = (char *)malloc(iov_len);
1034 		vec[i].iov_base = iov_base[i];
1035 		vec[i].iov_len = iov_len;
1036 	}
1037 
1038 	hdr.msg_iovlen = msg_iovlen;
1039 	hdr.msg_iov = vec;
1040 	EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
1041 
1042 	for (i = 0; i < msg_iovlen; i++)
1043 		free(iov_base[i]);
1044 }
1045 
TEST_F(tls,single_send_multiple_recv)1046 TEST_F(tls, single_send_multiple_recv)
1047 {
1048 	unsigned int total_len = TLS_PAYLOAD_MAX_LEN * 2;
1049 	unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
1050 	char send_mem[TLS_PAYLOAD_MAX_LEN * 2];
1051 	char recv_mem[TLS_PAYLOAD_MAX_LEN * 2];
1052 
1053 	memrnd(send_mem, sizeof(send_mem));
1054 
1055 	EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
1056 	memset(recv_mem, 0, total_len);
1057 
1058 	EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
1059 	EXPECT_NE(recv(self->cfd, recv_mem + send_len, send_len, 0), -1);
1060 	EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
1061 }
1062 
TEST_F(tls,multiple_send_single_recv)1063 TEST_F(tls, multiple_send_single_recv)
1064 {
1065 	unsigned int total_len = 2 * 10;
1066 	unsigned int send_len = 10;
1067 	char recv_mem[2 * 10];
1068 	char send_mem[10];
1069 
1070 	memrnd(send_mem, sizeof(send_mem));
1071 
1072 	EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
1073 	EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
1074 	memset(recv_mem, 0, total_len);
1075 	EXPECT_EQ(recv(self->cfd, recv_mem, total_len, MSG_WAITALL), total_len);
1076 
1077 	EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
1078 	EXPECT_EQ(memcmp(send_mem, recv_mem + send_len, send_len), 0);
1079 }
1080 
TEST_F(tls,single_send_multiple_recv_non_align)1081 TEST_F(tls, single_send_multiple_recv_non_align)
1082 {
1083 	const unsigned int total_len = 15;
1084 	const unsigned int recv_len = 10;
1085 	char recv_mem[recv_len * 2];
1086 	char send_mem[total_len];
1087 
1088 	memrnd(send_mem, sizeof(send_mem));
1089 
1090 	EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
1091 	memset(recv_mem, 0, total_len);
1092 
1093 	EXPECT_EQ(recv(self->cfd, recv_mem, recv_len, 0), recv_len);
1094 	EXPECT_EQ(recv(self->cfd, recv_mem + recv_len, recv_len, 0), 5);
1095 	EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
1096 }
1097 
TEST_F(tls,recv_partial)1098 TEST_F(tls, recv_partial)
1099 {
1100 	char const *test_str = "test_read_partial";
1101 	char const *test_str_first = "test_read";
1102 	char const *test_str_second = "_partial";
1103 	int send_len = strlen(test_str) + 1;
1104 	char recv_mem[18];
1105 
1106 	memset(recv_mem, 0, sizeof(recv_mem));
1107 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1108 	EXPECT_EQ(recv(self->cfd, recv_mem, strlen(test_str_first),
1109 		       MSG_WAITALL), strlen(test_str_first));
1110 	EXPECT_EQ(memcmp(test_str_first, recv_mem, strlen(test_str_first)), 0);
1111 	memset(recv_mem, 0, sizeof(recv_mem));
1112 	EXPECT_EQ(recv(self->cfd, recv_mem, strlen(test_str_second),
1113 		       MSG_WAITALL), strlen(test_str_second));
1114 	EXPECT_EQ(memcmp(test_str_second, recv_mem, strlen(test_str_second)),
1115 		  0);
1116 }
1117 
TEST_F(tls,recv_nonblock)1118 TEST_F(tls, recv_nonblock)
1119 {
1120 	char buf[4096];
1121 	bool err;
1122 
1123 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
1124 	err = (errno == EAGAIN || errno == EWOULDBLOCK);
1125 	EXPECT_EQ(err, true);
1126 }
1127 
TEST_F(tls,recv_peek)1128 TEST_F(tls, recv_peek)
1129 {
1130 	char const *test_str = "test_read_peek";
1131 	int send_len = strlen(test_str) + 1;
1132 	char buf[15];
1133 
1134 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1135 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_PEEK), send_len);
1136 	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1137 	memset(buf, 0, sizeof(buf));
1138 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1139 	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1140 }
1141 
TEST_F(tls,recv_peek_multiple)1142 TEST_F(tls, recv_peek_multiple)
1143 {
1144 	char const *test_str = "test_read_peek";
1145 	int send_len = strlen(test_str) + 1;
1146 	unsigned int num_peeks = 100;
1147 	char buf[15];
1148 	int i;
1149 
1150 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1151 	for (i = 0; i < num_peeks; i++) {
1152 		EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
1153 		EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1154 		memset(buf, 0, sizeof(buf));
1155 	}
1156 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1157 	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1158 }
1159 
TEST_F(tls,recv_peek_multiple_records)1160 TEST_F(tls, recv_peek_multiple_records)
1161 {
1162 	char const *test_str = "test_read_peek_mult_recs";
1163 	char const *test_str_first = "test_read_peek";
1164 	char const *test_str_second = "_mult_recs";
1165 	int len;
1166 	char buf[64];
1167 
1168 	len = strlen(test_str_first);
1169 	EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
1170 
1171 	len = strlen(test_str_second) + 1;
1172 	EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1173 
1174 	len = strlen(test_str_first);
1175 	memset(buf, 0, len);
1176 	EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
1177 
1178 	/* MSG_PEEK can only peek into the current record. */
1179 	len = strlen(test_str_first);
1180 	EXPECT_EQ(memcmp(test_str_first, buf, len), 0);
1181 
1182 	len = strlen(test_str) + 1;
1183 	memset(buf, 0, len);
1184 	EXPECT_EQ(recv(self->cfd, buf, len, MSG_WAITALL), len);
1185 
1186 	/* Non-MSG_PEEK will advance strparser (and therefore record)
1187 	 * however.
1188 	 */
1189 	len = strlen(test_str) + 1;
1190 	EXPECT_EQ(memcmp(test_str, buf, len), 0);
1191 
1192 	/* MSG_MORE will hold current record open, so later MSG_PEEK
1193 	 * will see everything.
1194 	 */
1195 	len = strlen(test_str_first);
1196 	EXPECT_EQ(send(self->fd, test_str_first, len, MSG_MORE), len);
1197 
1198 	len = strlen(test_str_second) + 1;
1199 	EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1200 
1201 	len = strlen(test_str) + 1;
1202 	memset(buf, 0, len);
1203 	EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
1204 
1205 	len = strlen(test_str) + 1;
1206 	EXPECT_EQ(memcmp(test_str, buf, len), 0);
1207 }
1208 
TEST_F(tls,recv_peek_large_buf_mult_recs)1209 TEST_F(tls, recv_peek_large_buf_mult_recs)
1210 {
1211 	char const *test_str = "test_read_peek_mult_recs";
1212 	char const *test_str_first = "test_read_peek";
1213 	char const *test_str_second = "_mult_recs";
1214 	int len;
1215 	char buf[64];
1216 
1217 	len = strlen(test_str_first);
1218 	EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
1219 
1220 	len = strlen(test_str_second) + 1;
1221 	EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1222 
1223 	len = strlen(test_str) + 1;
1224 	memset(buf, 0, len);
1225 	EXPECT_NE((len = recv(self->cfd, buf, len,
1226 			      MSG_PEEK | MSG_WAITALL)), -1);
1227 	len = strlen(test_str) + 1;
1228 	EXPECT_EQ(memcmp(test_str, buf, len), 0);
1229 }
1230 
TEST_F(tls,recv_lowat)1231 TEST_F(tls, recv_lowat)
1232 {
1233 	char send_mem[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
1234 	char recv_mem[20];
1235 	int lowat = 8;
1236 
1237 	EXPECT_EQ(send(self->fd, send_mem, 10, 0), 10);
1238 	EXPECT_EQ(send(self->fd, send_mem, 5, 0), 5);
1239 
1240 	memset(recv_mem, 0, 20);
1241 	EXPECT_EQ(setsockopt(self->cfd, SOL_SOCKET, SO_RCVLOWAT,
1242 			     &lowat, sizeof(lowat)), 0);
1243 	EXPECT_EQ(recv(self->cfd, recv_mem, 1, MSG_WAITALL), 1);
1244 	EXPECT_EQ(recv(self->cfd, recv_mem + 1, 6, MSG_WAITALL), 6);
1245 	EXPECT_EQ(recv(self->cfd, recv_mem + 7, 10, 0), 8);
1246 
1247 	EXPECT_EQ(memcmp(send_mem, recv_mem, 10), 0);
1248 	EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0);
1249 }
1250 
TEST_F(tls,bidir)1251 TEST_F(tls, bidir)
1252 {
1253 	char const *test_str = "test_read";
1254 	int send_len = 10;
1255 	char buf[10];
1256 	int ret;
1257 
1258 	if (!self->notls) {
1259 		struct tls_crypto_info_keys tls12;
1260 
1261 		tls_crypto_info_init(variant->tls_version, variant->cipher_type,
1262 				     &tls12, 0);
1263 
1264 		ret = setsockopt(self->fd, SOL_TLS, TLS_RX, &tls12,
1265 				 tls12.len);
1266 		ASSERT_EQ(ret, 0);
1267 
1268 		ret = setsockopt(self->cfd, SOL_TLS, TLS_TX, &tls12,
1269 				 tls12.len);
1270 		ASSERT_EQ(ret, 0);
1271 	}
1272 
1273 	ASSERT_EQ(strlen(test_str) + 1, send_len);
1274 
1275 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1276 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1277 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1278 
1279 	memset(buf, 0, sizeof(buf));
1280 
1281 	EXPECT_EQ(send(self->cfd, test_str, send_len, 0), send_len);
1282 	EXPECT_NE(recv(self->fd, buf, send_len, 0), -1);
1283 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1284 };
1285 
TEST_F(tls,pollin)1286 TEST_F(tls, pollin)
1287 {
1288 	char const *test_str = "test_poll";
1289 	struct pollfd fd = { 0, 0, 0 };
1290 	char buf[10];
1291 	int send_len = 10;
1292 
1293 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1294 	fd.fd = self->cfd;
1295 	fd.events = POLLIN;
1296 
1297 	EXPECT_EQ(poll(&fd, 1, 20), 1);
1298 	EXPECT_EQ(fd.revents & POLLIN, 1);
1299 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
1300 	/* Test timing out */
1301 	EXPECT_EQ(poll(&fd, 1, 20), 0);
1302 }
1303 
TEST_F(tls,poll_wait)1304 TEST_F(tls, poll_wait)
1305 {
1306 	char const *test_str = "test_poll_wait";
1307 	int send_len = strlen(test_str) + 1;
1308 	struct pollfd fd = { 0, 0, 0 };
1309 	char recv_mem[15];
1310 
1311 	fd.fd = self->cfd;
1312 	fd.events = POLLIN;
1313 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1314 	/* Set timeout to inf. secs */
1315 	EXPECT_EQ(poll(&fd, 1, -1), 1);
1316 	EXPECT_EQ(fd.revents & POLLIN, 1);
1317 	EXPECT_EQ(recv(self->cfd, recv_mem, send_len, MSG_WAITALL), send_len);
1318 }
1319 
TEST_F(tls,poll_wait_split)1320 TEST_F(tls, poll_wait_split)
1321 {
1322 	struct pollfd fd = { 0, 0, 0 };
1323 	char send_mem[20] = {};
1324 	char recv_mem[15];
1325 
1326 	fd.fd = self->cfd;
1327 	fd.events = POLLIN;
1328 	/* Send 20 bytes */
1329 	EXPECT_EQ(send(self->fd, send_mem, sizeof(send_mem), 0),
1330 		  sizeof(send_mem));
1331 	/* Poll with inf. timeout */
1332 	EXPECT_EQ(poll(&fd, 1, -1), 1);
1333 	EXPECT_EQ(fd.revents & POLLIN, 1);
1334 	EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), MSG_WAITALL),
1335 		  sizeof(recv_mem));
1336 
1337 	/* Now the remaining 5 bytes of record data are in TLS ULP */
1338 	fd.fd = self->cfd;
1339 	fd.events = POLLIN;
1340 	EXPECT_EQ(poll(&fd, 1, -1), 1);
1341 	EXPECT_EQ(fd.revents & POLLIN, 1);
1342 	EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0),
1343 		  sizeof(send_mem) - sizeof(recv_mem));
1344 }
1345 
TEST_F(tls,blocking)1346 TEST_F(tls, blocking)
1347 {
1348 	size_t data = 100000;
1349 	int res = fork();
1350 
1351 	EXPECT_NE(res, -1);
1352 
1353 	if (res) {
1354 		/* parent */
1355 		size_t left = data;
1356 		char buf[16384];
1357 		int status;
1358 		int pid2;
1359 
1360 		while (left) {
1361 			int res = send(self->fd, buf,
1362 				       left > 16384 ? 16384 : left, 0);
1363 
1364 			EXPECT_GE(res, 0);
1365 			left -= res;
1366 		}
1367 
1368 		pid2 = wait(&status);
1369 		EXPECT_EQ(status, 0);
1370 		EXPECT_EQ(res, pid2);
1371 	} else {
1372 		/* child */
1373 		size_t left = data;
1374 		char buf[16384];
1375 
1376 		while (left) {
1377 			int res = recv(self->cfd, buf,
1378 				       left > 16384 ? 16384 : left, 0);
1379 
1380 			EXPECT_GE(res, 0);
1381 			left -= res;
1382 		}
1383 	}
1384 }
1385 
TEST_F(tls,nonblocking)1386 TEST_F(tls, nonblocking)
1387 {
1388 	size_t data = 100000;
1389 	int sendbuf = 100;
1390 	int flags;
1391 	int res;
1392 
1393 	flags = fcntl(self->fd, F_GETFL, 0);
1394 	fcntl(self->fd, F_SETFL, flags | O_NONBLOCK);
1395 	fcntl(self->cfd, F_SETFL, flags | O_NONBLOCK);
1396 
1397 	/* Ensure nonblocking behavior by imposing a small send
1398 	 * buffer.
1399 	 */
1400 	EXPECT_EQ(setsockopt(self->fd, SOL_SOCKET, SO_SNDBUF,
1401 			     &sendbuf, sizeof(sendbuf)), 0);
1402 
1403 	res = fork();
1404 	EXPECT_NE(res, -1);
1405 
1406 	if (res) {
1407 		/* parent */
1408 		bool eagain = false;
1409 		size_t left = data;
1410 		char buf[16384];
1411 		int status;
1412 		int pid2;
1413 
1414 		while (left) {
1415 			int res = send(self->fd, buf,
1416 				       left > 16384 ? 16384 : left, 0);
1417 
1418 			if (res == -1 && errno == EAGAIN) {
1419 				eagain = true;
1420 				usleep(10000);
1421 				continue;
1422 			}
1423 			EXPECT_GE(res, 0);
1424 			left -= res;
1425 		}
1426 
1427 		EXPECT_TRUE(eagain);
1428 		pid2 = wait(&status);
1429 
1430 		EXPECT_EQ(status, 0);
1431 		EXPECT_EQ(res, pid2);
1432 	} else {
1433 		/* child */
1434 		bool eagain = false;
1435 		size_t left = data;
1436 		char buf[16384];
1437 
1438 		while (left) {
1439 			int res = recv(self->cfd, buf,
1440 				       left > 16384 ? 16384 : left, 0);
1441 
1442 			if (res == -1 && errno == EAGAIN) {
1443 				eagain = true;
1444 				usleep(10000);
1445 				continue;
1446 			}
1447 			EXPECT_GE(res, 0);
1448 			left -= res;
1449 		}
1450 		EXPECT_TRUE(eagain);
1451 	}
1452 }
1453 
1454 static void
test_mutliproc(struct __test_metadata * _metadata,struct _test_data_tls * self,bool sendpg,unsigned int n_readers,unsigned int n_writers)1455 test_mutliproc(struct __test_metadata *_metadata, struct _test_data_tls *self,
1456 	       bool sendpg, unsigned int n_readers, unsigned int n_writers)
1457 {
1458 	const unsigned int n_children = n_readers + n_writers;
1459 	const size_t data = 6 * 1000 * 1000;
1460 	const size_t file_sz = data / 100;
1461 	size_t read_bias, write_bias;
1462 	int i, fd, child_id;
1463 	char buf[file_sz];
1464 	pid_t pid;
1465 
1466 	/* Only allow multiples for simplicity */
1467 	ASSERT_EQ(!(n_readers % n_writers) || !(n_writers % n_readers), true);
1468 	read_bias = n_writers / n_readers ?: 1;
1469 	write_bias = n_readers / n_writers ?: 1;
1470 
1471 	/* prep a file to send */
1472 	fd = open("/tmp/", O_TMPFILE | O_RDWR, 0600);
1473 	ASSERT_GE(fd, 0);
1474 
1475 	memset(buf, 0xac, file_sz);
1476 	ASSERT_EQ(write(fd, buf, file_sz), file_sz);
1477 
1478 	/* spawn children */
1479 	for (child_id = 0; child_id < n_children; child_id++) {
1480 		pid = fork();
1481 		ASSERT_NE(pid, -1);
1482 		if (!pid)
1483 			break;
1484 	}
1485 
1486 	/* parent waits for all children */
1487 	if (pid) {
1488 		for (i = 0; i < n_children; i++) {
1489 			int status;
1490 
1491 			wait(&status);
1492 			EXPECT_EQ(status, 0);
1493 		}
1494 
1495 		return;
1496 	}
1497 
1498 	/* Split threads for reading and writing */
1499 	if (child_id < n_readers) {
1500 		size_t left = data * read_bias;
1501 		char rb[8001];
1502 
1503 		while (left) {
1504 			int res;
1505 
1506 			res = recv(self->cfd, rb,
1507 				   left > sizeof(rb) ? sizeof(rb) : left, 0);
1508 
1509 			EXPECT_GE(res, 0);
1510 			left -= res;
1511 		}
1512 	} else {
1513 		size_t left = data * write_bias;
1514 
1515 		while (left) {
1516 			int res;
1517 
1518 			ASSERT_EQ(lseek(fd, 0, SEEK_SET), 0);
1519 			if (sendpg)
1520 				res = sendfile(self->fd, fd, NULL,
1521 					       left > file_sz ? file_sz : left);
1522 			else
1523 				res = send(self->fd, buf,
1524 					   left > file_sz ? file_sz : left, 0);
1525 
1526 			EXPECT_GE(res, 0);
1527 			left -= res;
1528 		}
1529 	}
1530 }
1531 
TEST_F(tls,mutliproc_even)1532 TEST_F(tls, mutliproc_even)
1533 {
1534 	test_mutliproc(_metadata, self, false, 6, 6);
1535 }
1536 
TEST_F(tls,mutliproc_readers)1537 TEST_F(tls, mutliproc_readers)
1538 {
1539 	test_mutliproc(_metadata, self, false, 4, 12);
1540 }
1541 
TEST_F(tls,mutliproc_writers)1542 TEST_F(tls, mutliproc_writers)
1543 {
1544 	test_mutliproc(_metadata, self, false, 10, 2);
1545 }
1546 
TEST_F(tls,mutliproc_sendpage_even)1547 TEST_F(tls, mutliproc_sendpage_even)
1548 {
1549 	test_mutliproc(_metadata, self, true, 6, 6);
1550 }
1551 
TEST_F(tls,mutliproc_sendpage_readers)1552 TEST_F(tls, mutliproc_sendpage_readers)
1553 {
1554 	test_mutliproc(_metadata, self, true, 4, 12);
1555 }
1556 
TEST_F(tls,mutliproc_sendpage_writers)1557 TEST_F(tls, mutliproc_sendpage_writers)
1558 {
1559 	test_mutliproc(_metadata, self, true, 10, 2);
1560 }
1561 
TEST_F(tls,control_msg)1562 TEST_F(tls, control_msg)
1563 {
1564 	char *test_str = "test_read";
1565 	char record_type = 100;
1566 	int send_len = 10;
1567 	char buf[10];
1568 
1569 	if (self->notls)
1570 		SKIP(return, "no TLS support");
1571 
1572 	EXPECT_EQ(tls_send_cmsg(self->fd, record_type, test_str, send_len, 0),
1573 		  send_len);
1574 	/* Should fail because we didn't provide a control message */
1575 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
1576 
1577 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1578 				buf, sizeof(buf), MSG_WAITALL | MSG_PEEK),
1579 		  send_len);
1580 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1581 
1582 	/* Recv the message again without MSG_PEEK */
1583 	memset(buf, 0, sizeof(buf));
1584 
1585 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1586 				buf, sizeof(buf), MSG_WAITALL),
1587 		  send_len);
1588 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1589 }
1590 
TEST_F(tls,control_msg_nomerge)1591 TEST_F(tls, control_msg_nomerge)
1592 {
1593 	char *rec1 = "1111";
1594 	char *rec2 = "2222";
1595 	int send_len = 5;
1596 	char buf[15];
1597 
1598 	if (self->notls)
1599 		SKIP(return, "no TLS support");
1600 
1601 	EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec1, send_len, 0), send_len);
1602 	EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec2, send_len, 0), send_len);
1603 
1604 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), MSG_PEEK), send_len);
1605 	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1606 
1607 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), MSG_PEEK), send_len);
1608 	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1609 
1610 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), 0), send_len);
1611 	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1612 
1613 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), 0), send_len);
1614 	EXPECT_EQ(memcmp(buf, rec2, send_len), 0);
1615 }
1616 
TEST_F(tls,data_control_data)1617 TEST_F(tls, data_control_data)
1618 {
1619 	char *rec1 = "1111";
1620 	char *rec2 = "2222";
1621 	char *rec3 = "3333";
1622 	int send_len = 5;
1623 	char buf[15];
1624 
1625 	if (self->notls)
1626 		SKIP(return, "no TLS support");
1627 
1628 	EXPECT_EQ(send(self->fd, rec1, send_len, 0), send_len);
1629 	EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec2, send_len, 0), send_len);
1630 	EXPECT_EQ(send(self->fd, rec3, send_len, 0), send_len);
1631 
1632 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
1633 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
1634 }
1635 
TEST_F(tls,shutdown)1636 TEST_F(tls, shutdown)
1637 {
1638 	char const *test_str = "test_read";
1639 	int send_len = 10;
1640 	char buf[10];
1641 
1642 	ASSERT_EQ(strlen(test_str) + 1, send_len);
1643 
1644 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1645 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1646 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1647 
1648 	shutdown(self->fd, SHUT_RDWR);
1649 	shutdown(self->cfd, SHUT_RDWR);
1650 }
1651 
TEST_F(tls,shutdown_unsent)1652 TEST_F(tls, shutdown_unsent)
1653 {
1654 	char const *test_str = "test_read";
1655 	int send_len = 10;
1656 
1657 	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
1658 
1659 	shutdown(self->fd, SHUT_RDWR);
1660 	shutdown(self->cfd, SHUT_RDWR);
1661 }
1662 
TEST_F(tls,shutdown_reuse)1663 TEST_F(tls, shutdown_reuse)
1664 {
1665 	struct sockaddr_in addr;
1666 	int ret;
1667 
1668 	shutdown(self->fd, SHUT_RDWR);
1669 	shutdown(self->cfd, SHUT_RDWR);
1670 	close(self->cfd);
1671 
1672 	addr.sin_family = AF_INET;
1673 	addr.sin_addr.s_addr = htonl(INADDR_ANY);
1674 	addr.sin_port = 0;
1675 
1676 	ret = bind(self->fd, &addr, sizeof(addr));
1677 	EXPECT_EQ(ret, 0);
1678 	ret = listen(self->fd, 10);
1679 	EXPECT_EQ(ret, -1);
1680 	EXPECT_EQ(errno, EINVAL);
1681 
1682 	ret = connect(self->fd, &addr, sizeof(addr));
1683 	EXPECT_EQ(ret, -1);
1684 	EXPECT_EQ(errno, EISCONN);
1685 }
1686 
TEST_F(tls,getsockopt)1687 TEST_F(tls, getsockopt)
1688 {
1689 	struct tls_crypto_info_keys expect, get;
1690 	socklen_t len;
1691 
1692 	/* get only the version/cipher */
1693 	len = sizeof(struct tls_crypto_info);
1694 	memrnd(&get, sizeof(get));
1695 	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), 0);
1696 	EXPECT_EQ(len, sizeof(struct tls_crypto_info));
1697 	EXPECT_EQ(get.crypto_info.version, variant->tls_version);
1698 	EXPECT_EQ(get.crypto_info.cipher_type, variant->cipher_type);
1699 
1700 	/* get the full crypto_info */
1701 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &expect, 0);
1702 	len = expect.len;
1703 	memrnd(&get, sizeof(get));
1704 	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), 0);
1705 	EXPECT_EQ(len, expect.len);
1706 	EXPECT_EQ(get.crypto_info.version, variant->tls_version);
1707 	EXPECT_EQ(get.crypto_info.cipher_type, variant->cipher_type);
1708 	EXPECT_EQ(memcmp(&get, &expect, expect.len), 0);
1709 
1710 	/* short get should fail */
1711 	len = sizeof(struct tls_crypto_info) - 1;
1712 	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), -1);
1713 	EXPECT_EQ(errno, EINVAL);
1714 
1715 	/* partial get of the cipher data should fail */
1716 	len = expect.len - 1;
1717 	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), -1);
1718 	EXPECT_EQ(errno, EINVAL);
1719 }
1720 
TEST_F(tls,recv_efault)1721 TEST_F(tls, recv_efault)
1722 {
1723 	char *rec1 = "1111111111";
1724 	char *rec2 = "2222222222";
1725 	struct msghdr hdr = {};
1726 	struct iovec iov[2];
1727 	char recv_mem[12];
1728 	int ret;
1729 
1730 	if (self->notls)
1731 		SKIP(return, "no TLS support");
1732 
1733 	EXPECT_EQ(send(self->fd, rec1, 10, 0), 10);
1734 	EXPECT_EQ(send(self->fd, rec2, 10, 0), 10);
1735 
1736 	iov[0].iov_base = recv_mem;
1737 	iov[0].iov_len = sizeof(recv_mem);
1738 	iov[1].iov_base = NULL; /* broken iov to make process_rx_list fail */
1739 	iov[1].iov_len = 1;
1740 
1741 	hdr.msg_iovlen = 2;
1742 	hdr.msg_iov = iov;
1743 
1744 	EXPECT_EQ(recv(self->cfd, recv_mem, 1, 0), 1);
1745 	EXPECT_EQ(recv_mem[0], rec1[0]);
1746 
1747 	ret = recvmsg(self->cfd, &hdr, 0);
1748 	EXPECT_LE(ret, sizeof(recv_mem));
1749 	EXPECT_GE(ret, 9);
1750 	EXPECT_EQ(memcmp(rec1, recv_mem, 9), 0);
1751 	if (ret > 9)
1752 		EXPECT_EQ(memcmp(rec2, recv_mem + 9, ret - 9), 0);
1753 }
1754 
#define TLS_RECORD_TYPE_HANDSHAKE      0x16
/* key_update, length 1, update_not_requested */
static const char key_update_msg[] = "\x18\x00\x00\x01\x00";

/*
 * Send the KeyUpdate message on @fd as a handshake-type record.
 * sizeof(key_update_msg) includes the string literal's implicit trailing
 * NUL, so 6 bytes go out (the 5-byte message plus 0x00).
 * tls_recv_keyupdate() expects that same length.
 */
static void tls_send_keyupdate(struct __test_metadata *_metadata, int fd)
{
	const size_t msg_len = sizeof(key_update_msg);
	char *msg = (char *)key_update_msg;

	EXPECT_EQ(tls_send_cmsg(fd, TLS_RECORD_TYPE_HANDSHAKE, msg, msg_len, 0),
		  msg_len);
}
1766 
tls_recv_keyupdate(struct __test_metadata * _metadata,int fd,int flags)1767 static void tls_recv_keyupdate(struct __test_metadata *_metadata, int fd, int flags)
1768 {
1769 	char buf[100];
1770 
1771 	EXPECT_EQ(tls_recv_cmsg(_metadata, fd, TLS_RECORD_TYPE_HANDSHAKE, buf, sizeof(buf), flags),
1772 		  sizeof(key_update_msg));
1773 	EXPECT_EQ(memcmp(buf, key_update_msg, sizeof(key_update_msg)), 0);
1774 }
1775 
1776 /* set the key to 0 then 1 for RX, immediately to 1 for TX */
TEST_F(tls_basic,rekey_rx)1777 TEST_F(tls_basic, rekey_rx)
1778 {
1779 	struct tls_crypto_info_keys tls12_0, tls12_1;
1780 	char const *test_str = "test_message";
1781 	int send_len = strlen(test_str) + 1;
1782 	char buf[20];
1783 	int ret;
1784 
1785 	if (self->notls)
1786 		return;
1787 
1788 	tls_crypto_info_init(TLS_1_3_VERSION, TLS_CIPHER_AES_GCM_128,
1789 			     &tls12_0, 0);
1790 	tls_crypto_info_init(TLS_1_3_VERSION, TLS_CIPHER_AES_GCM_128,
1791 			     &tls12_1, 1);
1792 
1793 	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_1, tls12_1.len);
1794 	ASSERT_EQ(ret, 0);
1795 
1796 	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_0, tls12_0.len);
1797 	ASSERT_EQ(ret, 0);
1798 
1799 	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_1, tls12_1.len);
1800 	EXPECT_EQ(ret, 0);
1801 
1802 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1803 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1804 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1805 }
1806 
1807 /* set the key to 0 then 1 for TX, immediately to 1 for RX */
TEST_F(tls_basic,rekey_tx)1808 TEST_F(tls_basic, rekey_tx)
1809 {
1810 	struct tls_crypto_info_keys tls12_0, tls12_1;
1811 	char const *test_str = "test_message";
1812 	int send_len = strlen(test_str) + 1;
1813 	char buf[20];
1814 	int ret;
1815 
1816 	if (self->notls)
1817 		return;
1818 
1819 	tls_crypto_info_init(TLS_1_3_VERSION, TLS_CIPHER_AES_GCM_128,
1820 			     &tls12_0, 0);
1821 	tls_crypto_info_init(TLS_1_3_VERSION, TLS_CIPHER_AES_GCM_128,
1822 			     &tls12_1, 1);
1823 
1824 	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_0, tls12_0.len);
1825 	ASSERT_EQ(ret, 0);
1826 
1827 	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_1, tls12_1.len);
1828 	ASSERT_EQ(ret, 0);
1829 
1830 	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_1, tls12_1.len);
1831 	EXPECT_EQ(ret, 0);
1832 
1833 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1834 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1835 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1836 }
1837 
TEST_F(tls_basic,disconnect)1838 TEST_F(tls_basic, disconnect)
1839 {
1840 	char const *test_str = "test_message";
1841 	int send_len = strlen(test_str) + 1;
1842 	struct tls_crypto_info_keys key;
1843 	struct sockaddr_in addr;
1844 	char buf[20];
1845 	int ret;
1846 
1847 	if (self->notls)
1848 		return;
1849 
1850 	tls_crypto_info_init(TLS_1_3_VERSION, TLS_CIPHER_AES_GCM_128,
1851 			     &key, 0);
1852 
1853 	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &key, key.len);
1854 	ASSERT_EQ(ret, 0);
1855 
1856 	/* Pre-queue the data so that setsockopt parses it but doesn't
1857 	 * dequeue it from the TCP socket. recvmsg would dequeue.
1858 	 */
1859 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1860 
1861 	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &key, key.len);
1862 	ASSERT_EQ(ret, 0);
1863 
1864 	addr.sin_family = AF_UNSPEC;
1865 	addr.sin_addr.s_addr = htonl(INADDR_ANY);
1866 	addr.sin_port = 0;
1867 	ret = connect(self->cfd, &addr, sizeof(addr));
1868 	EXPECT_EQ(ret, -1);
1869 	EXPECT_EQ(errno, EOPNOTSUPP);
1870 
1871 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1872 }
1873 
TEST_F(tls,rekey)1874 TEST_F(tls, rekey)
1875 {
1876 	char const *test_str_1 = "test_message_before_rekey";
1877 	char const *test_str_2 = "test_message_after_rekey";
1878 	struct tls_crypto_info_keys tls12;
1879 	int send_len;
1880 	char buf[100];
1881 
1882 	if (variant->tls_version != TLS_1_3_VERSION)
1883 		return;
1884 
1885 	/* initial send/recv */
1886 	send_len = strlen(test_str_1) + 1;
1887 	EXPECT_EQ(send(self->fd, test_str_1, send_len, 0), send_len);
1888 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1889 	EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
1890 
1891 	/* update TX key */
1892 	tls_send_keyupdate(_metadata, self->fd);
1893 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1894 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
1895 
1896 	/* send after rekey */
1897 	send_len = strlen(test_str_2) + 1;
1898 	EXPECT_EQ(send(self->fd, test_str_2, send_len, 0), send_len);
1899 
1900 	/* can't receive the KeyUpdate without a control message */
1901 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
1902 
1903 	/* get KeyUpdate */
1904 	tls_recv_keyupdate(_metadata, self->cfd, 0);
1905 
1906 	/* recv blocking -> -EKEYEXPIRED */
1907 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), 0), -1);
1908 	EXPECT_EQ(errno, EKEYEXPIRED);
1909 
1910 	/* recv non-blocking -> -EKEYEXPIRED */
1911 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
1912 	EXPECT_EQ(errno, EKEYEXPIRED);
1913 
1914 	/* update RX key */
1915 	EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
1916 
1917 	/* recv after rekey */
1918 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1919 	EXPECT_EQ(memcmp(buf, test_str_2, send_len), 0);
1920 }
1921 
TEST_F(tls,rekey_fail)1922 TEST_F(tls, rekey_fail)
1923 {
1924 	char const *test_str_1 = "test_message_before_rekey";
1925 	char const *test_str_2 = "test_message_after_rekey";
1926 	struct tls_crypto_info_keys tls12;
1927 	int send_len;
1928 	char buf[100];
1929 
1930 	/* initial send/recv */
1931 	send_len = strlen(test_str_1) + 1;
1932 	EXPECT_EQ(send(self->fd, test_str_1, send_len, 0), send_len);
1933 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1934 	EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
1935 
1936 	/* update TX key */
1937 	tls_send_keyupdate(_metadata, self->fd);
1938 
1939 	if (variant->tls_version != TLS_1_3_VERSION) {
1940 		/* just check that rekey is not supported and return */
1941 		tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1942 		EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), -1);
1943 		EXPECT_EQ(errno, EBUSY);
1944 		return;
1945 	}
1946 
1947 	/* successful update */
1948 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1949 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
1950 
1951 	/* invalid update: change of version */
1952 	tls_crypto_info_init(TLS_1_2_VERSION, variant->cipher_type, &tls12, 1);
1953 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), -1);
1954 	EXPECT_EQ(errno, EINVAL);
1955 
1956 	/* invalid update (RX socket): change of version */
1957 	tls_crypto_info_init(TLS_1_2_VERSION, variant->cipher_type, &tls12, 1);
1958 	EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), -1);
1959 	EXPECT_EQ(errno, EINVAL);
1960 
1961 	/* invalid update: change of cipher */
1962 	if (variant->cipher_type == TLS_CIPHER_AES_GCM_256)
1963 		tls_crypto_info_init(variant->tls_version, TLS_CIPHER_CHACHA20_POLY1305, &tls12, 1);
1964 	else
1965 		tls_crypto_info_init(variant->tls_version, TLS_CIPHER_AES_GCM_256, &tls12, 1);
1966 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), -1);
1967 	EXPECT_EQ(errno, EINVAL);
1968 
1969 	/* send after rekey, the invalid updates shouldn't have an effect */
1970 	send_len = strlen(test_str_2) + 1;
1971 	EXPECT_EQ(send(self->fd, test_str_2, send_len, 0), send_len);
1972 
1973 	/* can't receive the KeyUpdate without a control message */
1974 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
1975 
1976 	/* get KeyUpdate */
1977 	tls_recv_keyupdate(_metadata, self->cfd, 0);
1978 
1979 	/* recv blocking -> -EKEYEXPIRED */
1980 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), 0), -1);
1981 	EXPECT_EQ(errno, EKEYEXPIRED);
1982 
1983 	/* recv non-blocking -> -EKEYEXPIRED */
1984 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
1985 	EXPECT_EQ(errno, EKEYEXPIRED);
1986 
1987 	/* update RX key */
1988 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1989 	EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
1990 
1991 	/* recv after rekey */
1992 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1993 	EXPECT_EQ(memcmp(buf, test_str_2, send_len), 0);
1994 }
1995 
TEST_F(tls,rekey_peek)1996 TEST_F(tls, rekey_peek)
1997 {
1998 	char const *test_str_1 = "test_message_before_rekey";
1999 	struct tls_crypto_info_keys tls12;
2000 	int send_len;
2001 	char buf[100];
2002 
2003 	if (variant->tls_version != TLS_1_3_VERSION)
2004 		return;
2005 
2006 	send_len = strlen(test_str_1) + 1;
2007 	EXPECT_EQ(send(self->fd, test_str_1, send_len, 0), send_len);
2008 
2009 	/* update TX key */
2010 	tls_send_keyupdate(_metadata, self->fd);
2011 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
2012 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2013 
2014 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
2015 	EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
2016 
2017 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
2018 	EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
2019 
2020 	/* can't receive the KeyUpdate without a control message */
2021 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
2022 
2023 	/* peek KeyUpdate */
2024 	tls_recv_keyupdate(_metadata, self->cfd, MSG_PEEK);
2025 
2026 	/* get KeyUpdate */
2027 	tls_recv_keyupdate(_metadata, self->cfd, 0);
2028 
2029 	/* update RX key */
2030 	EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2031 }
2032 
TEST_F(tls,splice_rekey)2033 TEST_F(tls, splice_rekey)
2034 {
2035 	int send_len = TLS_PAYLOAD_MAX_LEN / 2;
2036 	char mem_send[TLS_PAYLOAD_MAX_LEN];
2037 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
2038 	struct tls_crypto_info_keys tls12;
2039 	int p[2];
2040 
2041 	if (variant->tls_version != TLS_1_3_VERSION)
2042 		return;
2043 
2044 	memrnd(mem_send, sizeof(mem_send));
2045 
2046 	ASSERT_GE(pipe(p), 0);
2047 	EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
2048 
2049 	/* update TX key */
2050 	tls_send_keyupdate(_metadata, self->fd);
2051 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
2052 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2053 
2054 	EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
2055 
2056 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), send_len);
2057 	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
2058 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
2059 
2060 	/* can't splice the KeyUpdate */
2061 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), -1);
2062 	EXPECT_EQ(errno, EINVAL);
2063 
2064 	/* peek KeyUpdate */
2065 	tls_recv_keyupdate(_metadata, self->cfd, MSG_PEEK);
2066 
2067 	/* get KeyUpdate */
2068 	tls_recv_keyupdate(_metadata, self->cfd, 0);
2069 
2070 	/* can't splice before updating the key */
2071 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), -1);
2072 	EXPECT_EQ(errno, EKEYEXPIRED);
2073 
2074 	/* update RX key */
2075 	EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2076 
2077 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), send_len);
2078 	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
2079 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
2080 }
2081 
TEST_F(tls,rekey_peek_splice)2082 TEST_F(tls, rekey_peek_splice)
2083 {
2084 	char const *test_str_1 = "test_message_before_rekey";
2085 	struct tls_crypto_info_keys tls12;
2086 	int send_len;
2087 	char buf[100];
2088 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
2089 	int p[2];
2090 
2091 	if (variant->tls_version != TLS_1_3_VERSION)
2092 		return;
2093 
2094 	ASSERT_GE(pipe(p), 0);
2095 
2096 	send_len = strlen(test_str_1) + 1;
2097 	EXPECT_EQ(send(self->fd, test_str_1, send_len, 0), send_len);
2098 
2099 	/* update TX key */
2100 	tls_send_keyupdate(_metadata, self->fd);
2101 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
2102 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2103 
2104 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
2105 	EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
2106 
2107 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), send_len);
2108 	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
2109 	EXPECT_EQ(memcmp(mem_recv, test_str_1, send_len), 0);
2110 }
2111 
TEST_F(tls,rekey_getsockopt)2112 TEST_F(tls, rekey_getsockopt)
2113 {
2114 	struct tls_crypto_info_keys tls12;
2115 	struct tls_crypto_info_keys tls12_get;
2116 	socklen_t len;
2117 
2118 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 0);
2119 
2120 	len = tls12.len;
2121 	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_get, &len), 0);
2122 	EXPECT_EQ(len, tls12.len);
2123 	EXPECT_EQ(memcmp(&tls12_get, &tls12, tls12.len), 0);
2124 
2125 	len = tls12.len;
2126 	EXPECT_EQ(getsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_get, &len), 0);
2127 	EXPECT_EQ(len, tls12.len);
2128 	EXPECT_EQ(memcmp(&tls12_get, &tls12, tls12.len), 0);
2129 
2130 	if (variant->tls_version != TLS_1_3_VERSION)
2131 		return;
2132 
2133 	tls_send_keyupdate(_metadata, self->fd);
2134 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
2135 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2136 
2137 	tls_recv_keyupdate(_metadata, self->cfd, 0);
2138 	EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2139 
2140 	len = tls12.len;
2141 	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_get, &len), 0);
2142 	EXPECT_EQ(len, tls12.len);
2143 	EXPECT_EQ(memcmp(&tls12_get, &tls12, tls12.len), 0);
2144 
2145 	len = tls12.len;
2146 	EXPECT_EQ(getsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_get, &len), 0);
2147 	EXPECT_EQ(len, tls12.len);
2148 	EXPECT_EQ(memcmp(&tls12_get, &tls12, tls12.len), 0);
2149 }
2150 
TEST_F(tls,rekey_poll_pending)2151 TEST_F(tls, rekey_poll_pending)
2152 {
2153 	char const *test_str = "test_message_after_rekey";
2154 	struct tls_crypto_info_keys tls12;
2155 	struct pollfd pfd = { };
2156 	int send_len;
2157 	int ret;
2158 
2159 	if (variant->tls_version != TLS_1_3_VERSION)
2160 		return;
2161 
2162 	/* update TX key */
2163 	tls_send_keyupdate(_metadata, self->fd);
2164 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
2165 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2166 
2167 	/* get KeyUpdate */
2168 	tls_recv_keyupdate(_metadata, self->cfd, 0);
2169 
2170 	/* send immediately after rekey */
2171 	send_len = strlen(test_str) + 1;
2172 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
2173 
2174 	/* key hasn't been updated, expect cfd to be non-readable */
2175 	pfd.fd = self->cfd;
2176 	pfd.events = POLLIN;
2177 	EXPECT_EQ(poll(&pfd, 1, 0), 0);
2178 
2179 	ret = fork();
2180 	ASSERT_GE(ret, 0);
2181 
2182 	if (ret) {
2183 		int pid2, status;
2184 
2185 		/* wait before installing the new key */
2186 		sleep(1);
2187 
2188 		/* update RX key while poll() is sleeping */
2189 		EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2190 
2191 		pid2 = wait(&status);
2192 		EXPECT_EQ(pid2, ret);
2193 		EXPECT_EQ(status, 0);
2194 	} else {
2195 		pfd.fd = self->cfd;
2196 		pfd.events = POLLIN;
2197 		EXPECT_EQ(poll(&pfd, 1, 5000), 1);
2198 
2199 		exit(!__test_passed(_metadata));
2200 	}
2201 }
2202 
TEST_F(tls,rekey_poll_delay)2203 TEST_F(tls, rekey_poll_delay)
2204 {
2205 	char const *test_str = "test_message_after_rekey";
2206 	struct tls_crypto_info_keys tls12;
2207 	struct pollfd pfd = { };
2208 	int send_len;
2209 	int ret;
2210 
2211 	if (variant->tls_version != TLS_1_3_VERSION)
2212 		return;
2213 
2214 	/* update TX key */
2215 	tls_send_keyupdate(_metadata, self->fd);
2216 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
2217 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2218 
2219 	/* get KeyUpdate */
2220 	tls_recv_keyupdate(_metadata, self->cfd, 0);
2221 
2222 	ret = fork();
2223 	ASSERT_GE(ret, 0);
2224 
2225 	if (ret) {
2226 		int pid2, status;
2227 
2228 		/* wait before installing the new key */
2229 		sleep(1);
2230 
2231 		/* update RX key while poll() is sleeping */
2232 		EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2233 
2234 		sleep(1);
2235 		send_len = strlen(test_str) + 1;
2236 		EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
2237 
2238 		pid2 = wait(&status);
2239 		EXPECT_EQ(pid2, ret);
2240 		EXPECT_EQ(status, 0);
2241 	} else {
2242 		pfd.fd = self->cfd;
2243 		pfd.events = POLLIN;
2244 		EXPECT_EQ(poll(&pfd, 1, 5000), 1);
2245 		exit(!__test_passed(_metadata));
2246 	}
2247 }
2248 
/*
 * One pre-encrypted TLS record used as a test vector: the plaintext it is
 * expected to decrypt to, and the complete record bytes (record header
 * included) as they are written raw to the socket.
 */
struct raw_rec {
	unsigned int plain_len;			/* valid bytes in plain_data */
	unsigned char plain_data[100];		/* expected decrypted payload */
	unsigned int cipher_len;		/* valid bytes in cipher_data */
	unsigned char cipher_data[128];		/* raw on-the-wire record */
};
2255 
2256 /* TLS 1.2, AES_CCM, data, seqno:0, plaintext: 'Hello world' */
2257 static const struct raw_rec id0_data_l11 = {
2258 	.plain_len = 11,
2259 	.plain_data = {
2260 		0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f,
2261 		0x72, 0x6c, 0x64,
2262 	},
2263 	.cipher_len = 40,
2264 	.cipher_data = {
2265 		0x17, 0x03, 0x03, 0x00, 0x23, 0x00, 0x00, 0x00,
2266 		0x00, 0x00, 0x00, 0x00, 0x00, 0x26, 0xa2, 0x33,
2267 		0xde, 0x8d, 0x94, 0xf0, 0x29, 0x6c, 0xb1, 0xaf,
2268 		0x6a, 0x75, 0xb2, 0x93, 0xad, 0x45, 0xd5, 0xfd,
2269 		0x03, 0x51, 0x57, 0x8f, 0xf9, 0xcc, 0x3b, 0x42,
2270 	},
2271 };
2272 
2273 /* TLS 1.2, AES_CCM, ctrl, seqno:0, plaintext: '' */
2274 static const struct raw_rec id0_ctrl_l0 = {
2275 	.plain_len = 0,
2276 	.plain_data = {
2277 	},
2278 	.cipher_len = 29,
2279 	.cipher_data = {
2280 		0x16, 0x03, 0x03, 0x00, 0x18, 0x00, 0x00, 0x00,
2281 		0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x38, 0x7b,
2282 		0xa6, 0x1c, 0xdd, 0xa7, 0x19, 0x33, 0xab, 0xae,
2283 		0x88, 0xe1, 0xd2, 0x08, 0x4f,
2284 	},
2285 };
2286 
2287 /* TLS 1.2, AES_CCM, data, seqno:0, plaintext: '' */
2288 static const struct raw_rec id0_data_l0 = {
2289 	.plain_len = 0,
2290 	.plain_data = {
2291 	},
2292 	.cipher_len = 29,
2293 	.cipher_data = {
2294 		0x17, 0x03, 0x03, 0x00, 0x18, 0x00, 0x00, 0x00,
2295 		0x00, 0x00, 0x00, 0x00, 0x00, 0xc5, 0x37, 0x90,
2296 		0x70, 0x45, 0x89, 0xfb, 0x5c, 0xc7, 0x89, 0x03,
2297 		0x68, 0x80, 0xd3, 0xd8, 0xcc,
2298 	},
2299 };
2300 
2301 /* TLS 1.2, AES_CCM, data, seqno:1, plaintext: 'Hello world' */
2302 static const struct raw_rec id1_data_l11 = {
2303 	.plain_len = 11,
2304 	.plain_data = {
2305 		0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f,
2306 		0x72, 0x6c, 0x64,
2307 	},
2308 	.cipher_len = 40,
2309 	.cipher_data = {
2310 		0x17, 0x03, 0x03, 0x00, 0x23, 0x00, 0x00, 0x00,
2311 		0x00, 0x00, 0x00, 0x00, 0x01, 0x3a, 0x1a, 0x9c,
2312 		0xd0, 0xa8, 0x9a, 0xd6, 0x69, 0xd6, 0x1a, 0xe3,
2313 		0xb5, 0x1f, 0x0d, 0x2c, 0xe2, 0x97, 0x46, 0xff,
2314 		0x2b, 0xcc, 0x5a, 0xc4, 0xa3, 0xb9, 0xef, 0xba,
2315 	},
2316 };
2317 
2318 /* TLS 1.2, AES_CCM, ctrl, seqno:1, plaintext: '' */
2319 static const struct raw_rec id1_ctrl_l0 = {
2320 	.plain_len = 0,
2321 	.plain_data = {
2322 	},
2323 	.cipher_len = 29,
2324 	.cipher_data = {
2325 		0x16, 0x03, 0x03, 0x00, 0x18, 0x00, 0x00, 0x00,
2326 		0x00, 0x00, 0x00, 0x00, 0x01, 0x3e, 0xf0, 0xfe,
2327 		0xee, 0xd9, 0xe2, 0x5d, 0xc7, 0x11, 0x4c, 0xe6,
2328 		0xb4, 0x7e, 0xef, 0x40, 0x2b,
2329 	},
2330 };
2331 
2332 /* TLS 1.2, AES_CCM, data, seqno:1, plaintext: '' */
2333 static const struct raw_rec id1_data_l0 = {
2334 	.plain_len = 0,
2335 	.plain_data = {
2336 	},
2337 	.cipher_len = 29,
2338 	.cipher_data = {
2339 		0x17, 0x03, 0x03, 0x00, 0x18, 0x00, 0x00, 0x00,
2340 		0x00, 0x00, 0x00, 0x00, 0x01, 0xce, 0xfc, 0x86,
2341 		0xc8, 0xf0, 0x55, 0xf9, 0x47, 0x3f, 0x74, 0xdc,
2342 		0xc9, 0xbf, 0xfe, 0x5b, 0xb1,
2343 	},
2344 };
2345 
2346 /* TLS 1.2, AES_CCM, ctrl, seqno:2, plaintext: 'Hello world' */
2347 static const struct raw_rec id2_ctrl_l11 = {
2348 	.plain_len = 11,
2349 	.plain_data = {
2350 		0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f,
2351 		0x72, 0x6c, 0x64,
2352 	},
2353 	.cipher_len = 40,
2354 	.cipher_data = {
2355 		0x16, 0x03, 0x03, 0x00, 0x23, 0x00, 0x00, 0x00,
2356 		0x00, 0x00, 0x00, 0x00, 0x02, 0xe5, 0x3d, 0x19,
2357 		0x3d, 0xca, 0xb8, 0x16, 0xb6, 0xff, 0x79, 0x87,
2358 		0x2a, 0x04, 0x11, 0x3d, 0xf8, 0x64, 0x5f, 0x36,
2359 		0x8b, 0xa8, 0xee, 0x4c, 0x6d, 0x62, 0xa5, 0x00,
2360 	},
2361 };
2362 
2363 /* TLS 1.2, AES_CCM, data, seqno:2, plaintext: 'Hello world' */
2364 static const struct raw_rec id2_data_l11 = {
2365 	.plain_len = 11,
2366 	.plain_data = {
2367 		0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f,
2368 		0x72, 0x6c, 0x64,
2369 	},
2370 	.cipher_len = 40,
2371 	.cipher_data = {
2372 		0x17, 0x03, 0x03, 0x00, 0x23, 0x00, 0x00, 0x00,
2373 		0x00, 0x00, 0x00, 0x00, 0x02, 0xe5, 0x3d, 0x19,
2374 		0x3d, 0xca, 0xb8, 0x16, 0xb6, 0xff, 0x79, 0x87,
2375 		0x8e, 0xa1, 0xd0, 0xcd, 0x33, 0xb5, 0x86, 0x2b,
2376 		0x17, 0xf1, 0x52, 0x2a, 0x55, 0x62, 0x65, 0x11,
2377 	},
2378 };
2379 
2380 /* TLS 1.2, AES_CCM, ctrl, seqno:2, plaintext: '' */
2381 static const struct raw_rec id2_ctrl_l0 = {
2382 	.plain_len = 0,
2383 	.plain_data = {
2384 	},
2385 	.cipher_len = 29,
2386 	.cipher_data = {
2387 		0x16, 0x03, 0x03, 0x00, 0x18, 0x00, 0x00, 0x00,
2388 		0x00, 0x00, 0x00, 0x00, 0x02, 0xdc, 0x5c, 0x0e,
2389 		0x41, 0xdd, 0xba, 0xd3, 0xcc, 0xcf, 0x6d, 0xd9,
2390 		0x06, 0xdb, 0x79, 0xe5, 0x5d,
2391 	},
2392 };
2393 
2394 /* TLS 1.2, AES_CCM, data, seqno:2, plaintext: '' */
2395 static const struct raw_rec id2_data_l0 = {
2396 	.plain_len = 0,
2397 	.plain_data = {
2398 	},
2399 	.cipher_len = 29,
2400 	.cipher_data = {
2401 		0x17, 0x03, 0x03, 0x00, 0x18, 0x00, 0x00, 0x00,
2402 		0x00, 0x00, 0x00, 0x00, 0x02, 0xc3, 0xca, 0x26,
2403 		0x22, 0xe4, 0x25, 0xfb, 0x5f, 0x6d, 0xbf, 0x83,
2404 		0x30, 0x48, 0x69, 0x1a, 0x47,
2405 	},
2406 };
2407 
/*
 * Connection for the zero-length record tests: pre-encrypted records are
 * written raw on fd, and cfd (which has TLS_RX) decrypts them.
 */
FIXTURE(zero_len)
{
	int fd;		/* sender; no TLS keys, writes raw records */
	int cfd;	/* receiver; TLS_RX installed in setup */
	bool notls;	/* presumably set by ulp_sock_pair() when TLS is unavailable */
};
2413 
/*
 * Per-variant script:
 *  - recs: up to four records sent back to back (a NULL entry ends the
 *    list);
 *  - recv_ret: the expected result of each successive non-blocking
 *    receive, either a byte count or a negative errno. -EAGAIN ends the
 *    checks.
 */
FIXTURE_VARIANT(zero_len)
{
	const struct raw_rec *recs[4];	/* records to send, NULL-terminated if < 4 */
	ssize_t recv_ret[4];		/* expected recv results, in order */
};
2419 
FIXTURE_VARIANT_ADD(zero_len,data_data_data)2420 FIXTURE_VARIANT_ADD(zero_len, data_data_data)
2421 {
2422 	.recs = { &id0_data_l11, &id1_data_l11, &id2_data_l11, },
2423 	.recv_ret = { 33, -EAGAIN, },
2424 };
2425 
FIXTURE_VARIANT_ADD(zero_len,data_0ctrl_data)2426 FIXTURE_VARIANT_ADD(zero_len, data_0ctrl_data)
2427 {
2428 	.recs = { &id0_data_l11, &id1_ctrl_l0, &id2_data_l11, },
2429 	.recv_ret = { 11, 0, 11, -EAGAIN, },
2430 };
2431 
2432 FIXTURE_VARIANT_ADD(zero_len, 0data_0data_0data)
2433 {
2434 	.recs = { &id0_data_l0, &id1_data_l0, &id2_data_l0, },
2435 	.recv_ret = { -EAGAIN, },
2436 };
2437 
2438 FIXTURE_VARIANT_ADD(zero_len, 0data_0data_ctrl)
2439 {
2440 	.recs = { &id0_data_l0, &id1_data_l0, &id2_ctrl_l11, },
2441 	.recv_ret = { 0, 11, -EAGAIN, },
2442 };
2443 
2444 FIXTURE_VARIANT_ADD(zero_len, 0data_0data_0ctrl)
2445 {
2446 	.recs = { &id0_data_l0, &id1_data_l0, &id2_ctrl_l0, },
2447 	.recv_ret = { 0, 0, -EAGAIN, },
2448 };
2449 
2450 FIXTURE_VARIANT_ADD(zero_len, 0ctrl_0ctrl_0ctrl)
2451 {
2452 	.recs = { &id0_ctrl_l0, &id1_ctrl_l0, &id2_ctrl_l0, },
2453 	.recv_ret = { 0, 0, 0, -EAGAIN, },
2454 };
2455 
2456 FIXTURE_VARIANT_ADD(zero_len, 0data_0data_data)
2457 {
2458 	.recs = { &id0_data_l0, &id1_data_l0, &id2_data_l11, },
2459 	.recv_ret = { 11, -EAGAIN, },
2460 };
2461 
FIXTURE_VARIANT_ADD(zero_len,data_0data_0data)2462 FIXTURE_VARIANT_ADD(zero_len, data_0data_0data)
2463 {
2464 	.recs = { &id0_data_l11, &id1_data_l0, &id2_data_l0, },
2465 	.recv_ret = { 11, -EAGAIN, },
2466 };
2467 
FIXTURE_SETUP(zero_len)2468 FIXTURE_SETUP(zero_len)
2469 {
2470 	struct tls_crypto_info_keys tls12;
2471 	int ret;
2472 
2473 	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_CCM_128,
2474 			     &tls12, 0);
2475 
2476 	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
2477 	if (self->notls)
2478 		return;
2479 
2480 	/* Don't install keys on fd, we'll send raw records */
2481 	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len);
2482 	ASSERT_EQ(ret, 0);
2483 }
2484 
FIXTURE_TEARDOWN(zero_len)2485 FIXTURE_TEARDOWN(zero_len)
2486 {
2487 	close(self->fd);
2488 	close(self->cfd);
2489 }
2490 
TEST_F(zero_len,test)2491 TEST_F(zero_len, test)
2492 {
2493 	const struct raw_rec *const *rec;
2494 	unsigned char buf[128];
2495 	int rec_off;
2496 	int i;
2497 
2498 	for (i = 0; i < 4 && variant->recs[i]; i++)
2499 		EXPECT_EQ(send(self->fd, variant->recs[i]->cipher_data,
2500 			       variant->recs[i]->cipher_len, 0),
2501 			  variant->recs[i]->cipher_len);
2502 
2503 	rec = &variant->recs[0];
2504 	rec_off = 0;
2505 	for (i = 0; i < 4; i++) {
2506 		int j, ret;
2507 
2508 		ret = variant->recv_ret[i] >= 0 ? variant->recv_ret[i] : -1;
2509 		EXPECT_EQ(__tls_recv_cmsg(_metadata, self->cfd, NULL,
2510 					  buf, sizeof(buf), MSG_DONTWAIT), ret);
2511 		if (ret == -1)
2512 			EXPECT_EQ(errno, -variant->recv_ret[i]);
2513 		if (variant->recv_ret[i] == -EAGAIN)
2514 			break;
2515 
2516 		for (j = 0; j < ret; j++) {
2517 			while (rec_off == (*rec)->plain_len) {
2518 				rec++;
2519 				rec_off = 0;
2520 			}
2521 			EXPECT_EQ(buf[j], (*rec)->plain_data[rec_off]);
2522 			rec_off++;
2523 		}
2524 	}
2525 };
2526 
/*
 * Two independent connections:
 *  - fd/cfd produce genuine records: fd has TLS_TX, and cfd is read raw
 *    to capture the ciphertext;
 *  - fd2/cfd2 replay (possibly corrupted) copies: fd2 writes raw bytes,
 *    and cfd2 has TLS_RX with the same keys.
 */
FIXTURE(tls_err)
{
	int fd;
	int cfd;
	int fd2;
	int cfd2;
	bool notls;	/* presumably set by ulp_sock_pair() when TLS is unavailable */
};
2533 
/* Only the protocol version varies; setup always uses AES-GCM-128 keys. */
FIXTURE_VARIANT(tls_err)
{
	uint16_t tls_version;	/* TLS_1_2_VERSION or TLS_1_3_VERSION */
};
2538 
2539 FIXTURE_VARIANT_ADD(tls_err, 12_aes_gcm)
2540 {
2541 	.tls_version = TLS_1_2_VERSION,
2542 };
2543 
2544 FIXTURE_VARIANT_ADD(tls_err, 13_aes_gcm)
2545 {
2546 	.tls_version = TLS_1_3_VERSION,
2547 };
2548 
FIXTURE_SETUP(tls_err)2549 FIXTURE_SETUP(tls_err)
2550 {
2551 	struct tls_crypto_info_keys tls12;
2552 	int ret;
2553 
2554 	tls_crypto_info_init(variant->tls_version, TLS_CIPHER_AES_GCM_128,
2555 			     &tls12, 0);
2556 
2557 	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
2558 	ulp_sock_pair(_metadata, &self->fd2, &self->cfd2, &self->notls);
2559 	if (self->notls)
2560 		return;
2561 
2562 	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
2563 	ASSERT_EQ(ret, 0);
2564 
2565 	ret = setsockopt(self->cfd2, SOL_TLS, TLS_RX, &tls12, tls12.len);
2566 	ASSERT_EQ(ret, 0);
2567 }
2568 
FIXTURE_TEARDOWN(tls_err)2569 FIXTURE_TEARDOWN(tls_err)
2570 {
2571 	close(self->fd);
2572 	close(self->cfd);
2573 	close(self->fd2);
2574 	close(self->cfd2);
2575 }
2576 
TEST_F(tls_err,bad_rec)2577 TEST_F(tls_err, bad_rec)
2578 {
2579 	char buf[64];
2580 
2581 	if (self->notls)
2582 		SKIP(return, "no TLS support");
2583 
2584 	memset(buf, 0x55, sizeof(buf));
2585 	EXPECT_EQ(send(self->fd2, buf, sizeof(buf), 0), sizeof(buf));
2586 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2587 	EXPECT_EQ(errno, EMSGSIZE);
2588 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), MSG_DONTWAIT), -1);
2589 	EXPECT_EQ(errno, EAGAIN);
2590 }
2591 
TEST_F(tls_err,bad_auth)2592 TEST_F(tls_err, bad_auth)
2593 {
2594 	char buf[128];
2595 	int n;
2596 
2597 	if (self->notls)
2598 		SKIP(return, "no TLS support");
2599 
2600 	memrnd(buf, sizeof(buf) / 2);
2601 	EXPECT_EQ(send(self->fd, buf, sizeof(buf) / 2, 0), sizeof(buf) / 2);
2602 	n = recv(self->cfd, buf, sizeof(buf), 0);
2603 	EXPECT_GT(n, sizeof(buf) / 2);
2604 
2605 	buf[n - 1]++;
2606 
2607 	EXPECT_EQ(send(self->fd2, buf, n, 0), n);
2608 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2609 	EXPECT_EQ(errno, EBADMSG);
2610 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2611 	EXPECT_EQ(errno, EBADMSG);
2612 }
2613 
TEST_F(tls_err,bad_in_large_read)2614 TEST_F(tls_err, bad_in_large_read)
2615 {
2616 	char txt[3][64];
2617 	char cip[3][128];
2618 	char buf[3 * 128];
2619 	int i, n;
2620 
2621 	if (self->notls)
2622 		SKIP(return, "no TLS support");
2623 
2624 	/* Put 3 records in the sockets */
2625 	for (i = 0; i < 3; i++) {
2626 		memrnd(txt[i], sizeof(txt[i]));
2627 		EXPECT_EQ(send(self->fd, txt[i], sizeof(txt[i]), 0),
2628 			  sizeof(txt[i]));
2629 		n = recv(self->cfd, cip[i], sizeof(cip[i]), 0);
2630 		EXPECT_GT(n, sizeof(txt[i]));
2631 		/* Break the third message */
2632 		if (i == 2)
2633 			cip[2][n - 1]++;
2634 		EXPECT_EQ(send(self->fd2, cip[i], n, 0), n);
2635 	}
2636 
2637 	/* We should be able to receive the first two messages */
2638 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt[0]) * 2);
2639 	EXPECT_EQ(memcmp(buf, txt[0], sizeof(txt[0])), 0);
2640 	EXPECT_EQ(memcmp(buf + sizeof(txt[0]), txt[1], sizeof(txt[1])), 0);
2641 	/* Third mesasge is bad */
2642 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2643 	EXPECT_EQ(errno, EBADMSG);
2644 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2645 	EXPECT_EQ(errno, EBADMSG);
2646 }
2647 
TEST_F(tls_err,bad_cmsg)2648 TEST_F(tls_err, bad_cmsg)
2649 {
2650 	char *test_str = "test_read";
2651 	int send_len = 10;
2652 	char cip[128];
2653 	char buf[128];
2654 	char txt[64];
2655 	int n;
2656 
2657 	if (self->notls)
2658 		SKIP(return, "no TLS support");
2659 
2660 	/* Queue up one data record */
2661 	memrnd(txt, sizeof(txt));
2662 	EXPECT_EQ(send(self->fd, txt, sizeof(txt), 0), sizeof(txt));
2663 	n = recv(self->cfd, cip, sizeof(cip), 0);
2664 	EXPECT_GT(n, sizeof(txt));
2665 	EXPECT_EQ(send(self->fd2, cip, n, 0), n);
2666 
2667 	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
2668 	n = recv(self->cfd, cip, sizeof(cip), 0);
2669 	cip[n - 1]++; /* Break it */
2670 	EXPECT_GT(n, send_len);
2671 	EXPECT_EQ(send(self->fd2, cip, n, 0), n);
2672 
2673 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt));
2674 	EXPECT_EQ(memcmp(buf, txt, sizeof(txt)), 0);
2675 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2676 	EXPECT_EQ(errno, EBADMSG);
2677 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2678 	EXPECT_EQ(errno, EBADMSG);
2679 }
2680 
TEST_F(tls_err,timeo)2681 TEST_F(tls_err, timeo)
2682 {
2683 	struct timeval tv = { .tv_usec = 10000, };
2684 	char buf[128];
2685 	int ret;
2686 
2687 	if (self->notls)
2688 		SKIP(return, "no TLS support");
2689 
2690 	ret = setsockopt(self->cfd2, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
2691 	ASSERT_EQ(ret, 0);
2692 
2693 	ret = fork();
2694 	ASSERT_GE(ret, 0);
2695 
2696 	if (ret) {
2697 		usleep(1000); /* Give child a head start */
2698 
2699 		EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2700 		EXPECT_EQ(errno, EAGAIN);
2701 
2702 		EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2703 		EXPECT_EQ(errno, EAGAIN);
2704 
2705 		wait(&ret);
2706 	} else {
2707 		EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2708 		EXPECT_EQ(errno, EAGAIN);
2709 		exit(0);
2710 	}
2711 }
2712 
TEST_F(tls_err,poll_partial_rec)2713 TEST_F(tls_err, poll_partial_rec)
2714 {
2715 	struct pollfd pfd = { };
2716 	ssize_t rec_len;
2717 	char rec[256];
2718 	char buf[128];
2719 
2720 	if (self->notls)
2721 		SKIP(return, "no TLS support");
2722 
2723 	pfd.fd = self->cfd2;
2724 	pfd.events = POLLIN;
2725 	EXPECT_EQ(poll(&pfd, 1, 1), 0);
2726 
2727 	memrnd(buf, sizeof(buf));
2728 	EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
2729 	rec_len = recv(self->cfd, rec, sizeof(rec), 0);
2730 	EXPECT_GT(rec_len, sizeof(buf));
2731 
2732 	/* Write 100B, not the full record ... */
2733 	EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
2734 	/* ... no full record should mean no POLLIN */
2735 	pfd.fd = self->cfd2;
2736 	pfd.events = POLLIN;
2737 	EXPECT_EQ(poll(&pfd, 1, 1), 0);
2738 	/* Now write the rest, and it should all pop out of the other end. */
2739 	EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0), rec_len - 100);
2740 	pfd.fd = self->cfd2;
2741 	pfd.events = POLLIN;
2742 	EXPECT_EQ(poll(&pfd, 1, 1), 1);
2743 	EXPECT_EQ(recv(self->cfd2, rec, sizeof(rec), 0), sizeof(buf));
2744 	EXPECT_EQ(memcmp(buf, rec, sizeof(buf)), 0);
2745 }
2746 
TEST_F(tls_err,epoll_partial_rec)2747 TEST_F(tls_err, epoll_partial_rec)
2748 {
2749 	struct epoll_event ev, events[10];
2750 	ssize_t rec_len;
2751 	char rec[256];
2752 	char buf[128];
2753 	int epollfd;
2754 
2755 	if (self->notls)
2756 		SKIP(return, "no TLS support");
2757 
2758 	epollfd = epoll_create1(0);
2759 	ASSERT_GE(epollfd, 0);
2760 
2761 	memset(&ev, 0, sizeof(ev));
2762 	ev.events = EPOLLIN;
2763 	ev.data.fd = self->cfd2;
2764 	ASSERT_GE(epoll_ctl(epollfd, EPOLL_CTL_ADD, self->cfd2, &ev), 0);
2765 
2766 	EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 0);
2767 
2768 	memrnd(buf, sizeof(buf));
2769 	EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
2770 	rec_len = recv(self->cfd, rec, sizeof(rec), 0);
2771 	EXPECT_GT(rec_len, sizeof(buf));
2772 
2773 	/* Write 100B, not the full record ... */
2774 	EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
2775 	/* ... no full record should mean no POLLIN */
2776 	EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 0);
2777 	/* Now write the rest, and it should all pop out of the other end. */
2778 	EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0), rec_len - 100);
2779 	EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 1);
2780 	EXPECT_EQ(recv(self->cfd2, rec, sizeof(rec), 0), sizeof(buf));
2781 	EXPECT_EQ(memcmp(buf, rec, sizeof(buf)), 0);
2782 
2783 	close(epollfd);
2784 }
2785 
TEST_F(tls_err,poll_partial_rec_async)2786 TEST_F(tls_err, poll_partial_rec_async)
2787 {
2788 	struct pollfd pfd = { };
2789 	ssize_t rec_len;
2790 	char rec[256];
2791 	char buf[128];
2792 	char token;
2793 	int p[2];
2794 	int ret;
2795 
2796 	if (self->notls)
2797 		SKIP(return, "no TLS support");
2798 
2799 	ASSERT_GE(pipe(p), 0);
2800 
2801 	memrnd(buf, sizeof(buf));
2802 	EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
2803 	rec_len = recv(self->cfd, rec, sizeof(rec), 0);
2804 	EXPECT_GT(rec_len, sizeof(buf));
2805 
2806 	ret = fork();
2807 	ASSERT_GE(ret, 0);
2808 
2809 	if (ret) {
2810 		int status, pid2;
2811 
2812 		close(p[1]);
2813 		usleep(1000); /* Give child a head start */
2814 
2815 		EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
2816 
2817 		EXPECT_EQ(read(p[0], &token, 1), 1); /* Barrier #1 */
2818 
2819 		EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0),
2820 			  rec_len - 100);
2821 
2822 		pid2 = wait(&status);
2823 		EXPECT_EQ(pid2, ret);
2824 		EXPECT_EQ(status, 0);
2825 	} else {
2826 		close(p[0]);
2827 
2828 		/* Child should sleep in poll(), never get a wake */
2829 		pfd.fd = self->cfd2;
2830 		pfd.events = POLLIN;
2831 		EXPECT_EQ(poll(&pfd, 1, 20), 0);
2832 
2833 		EXPECT_EQ(write(p[1], &token, 1), 1); /* Barrier #1 */
2834 
2835 		pfd.fd = self->cfd2;
2836 		pfd.events = POLLIN;
2837 		EXPECT_EQ(poll(&pfd, 1, 20), 1);
2838 
2839 		exit(!__test_passed(_metadata));
2840 	}
2841 }
2842 
2843 /* Use OOB+large send to trigger copy mode due to memory pressure.
2844  * OOB causes a short read.
2845  */
TEST_F(tls_err,oob_pressure)2846 TEST_F(tls_err, oob_pressure)
2847 {
2848 	char buf[1<<16];
2849 	int i;
2850 
2851 	memrnd(buf, sizeof(buf));
2852 
2853 	EXPECT_EQ(send(self->fd2, buf, 5, MSG_OOB), 5);
2854 	EXPECT_EQ(send(self->fd2, buf, sizeof(buf), 0), sizeof(buf));
2855 	for (i = 0; i < 64; i++)
2856 		EXPECT_EQ(send(self->fd2, buf, 5, MSG_OOB), 5);
2857 }
2858 
TEST(non_established)2859 TEST(non_established) {
2860 	struct tls12_crypto_info_aes_gcm_256 tls12;
2861 	struct sockaddr_in addr;
2862 	int sfd, ret, fd;
2863 	socklen_t len;
2864 
2865 	len = sizeof(addr);
2866 
2867 	memset(&tls12, 0, sizeof(tls12));
2868 	tls12.info.version = TLS_1_2_VERSION;
2869 	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
2870 
2871 	addr.sin_family = AF_INET;
2872 	addr.sin_addr.s_addr = htonl(INADDR_ANY);
2873 	addr.sin_port = 0;
2874 
2875 	fd = socket(AF_INET, SOCK_STREAM, 0);
2876 	sfd = socket(AF_INET, SOCK_STREAM, 0);
2877 
2878 	ret = bind(sfd, &addr, sizeof(addr));
2879 	ASSERT_EQ(ret, 0);
2880 	ret = listen(sfd, 10);
2881 	ASSERT_EQ(ret, 0);
2882 
2883 	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2884 	EXPECT_EQ(ret, -1);
2885 	/* TLS ULP not supported */
2886 	if (errno == ENOENT)
2887 		return;
2888 	EXPECT_EQ(errno, ENOTCONN);
2889 
2890 	ret = setsockopt(sfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2891 	EXPECT_EQ(ret, -1);
2892 	EXPECT_EQ(errno, ENOTCONN);
2893 
2894 	ret = getsockname(sfd, &addr, &len);
2895 	ASSERT_EQ(ret, 0);
2896 
2897 	ret = connect(fd, &addr, sizeof(addr));
2898 	ASSERT_EQ(ret, 0);
2899 
2900 	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2901 	ASSERT_EQ(ret, 0);
2902 
2903 	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2904 	EXPECT_EQ(ret, -1);
2905 	EXPECT_EQ(errno, EEXIST);
2906 
2907 	close(fd);
2908 	close(sfd);
2909 }
2910 
TEST(keysizes)2911 TEST(keysizes) {
2912 	struct tls12_crypto_info_aes_gcm_256 tls12;
2913 	int ret, fd, cfd;
2914 	bool notls;
2915 
2916 	memset(&tls12, 0, sizeof(tls12));
2917 	tls12.info.version = TLS_1_2_VERSION;
2918 	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
2919 
2920 	ulp_sock_pair(_metadata, &fd, &cfd, &notls);
2921 
2922 	if (!notls) {
2923 		ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
2924 				 sizeof(tls12));
2925 		EXPECT_EQ(ret, 0);
2926 
2927 		ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
2928 				 sizeof(tls12));
2929 		EXPECT_EQ(ret, 0);
2930 	}
2931 
2932 	close(fd);
2933 	close(cfd);
2934 }
2935 
TEST(no_pad)2936 TEST(no_pad) {
2937 	struct tls12_crypto_info_aes_gcm_256 tls12;
2938 	int ret, fd, cfd, val;
2939 	socklen_t len;
2940 	bool notls;
2941 
2942 	memset(&tls12, 0, sizeof(tls12));
2943 	tls12.info.version = TLS_1_3_VERSION;
2944 	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
2945 
2946 	ulp_sock_pair(_metadata, &fd, &cfd, &notls);
2947 
2948 	if (notls)
2949 		exit(KSFT_SKIP);
2950 
2951 	ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12, sizeof(tls12));
2952 	EXPECT_EQ(ret, 0);
2953 
2954 	ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12, sizeof(tls12));
2955 	EXPECT_EQ(ret, 0);
2956 
2957 	val = 1;
2958 	ret = setsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2959 			 (void *)&val, sizeof(val));
2960 	EXPECT_EQ(ret, 0);
2961 
2962 	len = sizeof(val);
2963 	val = 2;
2964 	ret = getsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2965 			 (void *)&val, &len);
2966 	EXPECT_EQ(ret, 0);
2967 	EXPECT_EQ(val, 1);
2968 	EXPECT_EQ(len, 4);
2969 
2970 	val = 0;
2971 	ret = setsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2972 			 (void *)&val, sizeof(val));
2973 	EXPECT_EQ(ret, 0);
2974 
2975 	len = sizeof(val);
2976 	val = 2;
2977 	ret = getsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2978 			 (void *)&val, &len);
2979 	EXPECT_EQ(ret, 0);
2980 	EXPECT_EQ(val, 0);
2981 	EXPECT_EQ(len, 4);
2982 
2983 	close(fd);
2984 	close(cfd);
2985 }
2986 
TEST(tls_v6ops)2987 TEST(tls_v6ops) {
2988 	struct tls_crypto_info_keys tls12;
2989 	struct sockaddr_in6 addr, addr2;
2990 	int sfd, ret, fd;
2991 	socklen_t len, len2;
2992 
2993 	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_128, &tls12, 0);
2994 
2995 	addr.sin6_family = AF_INET6;
2996 	addr.sin6_addr = in6addr_any;
2997 	addr.sin6_port = 0;
2998 
2999 	fd = socket(AF_INET6, SOCK_STREAM, 0);
3000 	sfd = socket(AF_INET6, SOCK_STREAM, 0);
3001 
3002 	ret = bind(sfd, &addr, sizeof(addr));
3003 	ASSERT_EQ(ret, 0);
3004 	ret = listen(sfd, 10);
3005 	ASSERT_EQ(ret, 0);
3006 
3007 	len = sizeof(addr);
3008 	ret = getsockname(sfd, &addr, &len);
3009 	ASSERT_EQ(ret, 0);
3010 
3011 	ret = connect(fd, &addr, sizeof(addr));
3012 	ASSERT_EQ(ret, 0);
3013 
3014 	len = sizeof(addr);
3015 	ret = getsockname(fd, &addr, &len);
3016 	ASSERT_EQ(ret, 0);
3017 
3018 	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
3019 	if (ret) {
3020 		ASSERT_EQ(errno, ENOENT);
3021 		SKIP(return, "no TLS support");
3022 	}
3023 	ASSERT_EQ(ret, 0);
3024 
3025 	ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
3026 	ASSERT_EQ(ret, 0);
3027 
3028 	ret = setsockopt(fd, SOL_TLS, TLS_RX, &tls12, tls12.len);
3029 	ASSERT_EQ(ret, 0);
3030 
3031 	len2 = sizeof(addr2);
3032 	ret = getsockname(fd, &addr2, &len2);
3033 	ASSERT_EQ(ret, 0);
3034 
3035 	EXPECT_EQ(len2, len);
3036 	EXPECT_EQ(memcmp(&addr, &addr2, len), 0);
3037 
3038 	close(fd);
3039 	close(sfd);
3040 }
3041 
TEST(prequeue)3042 TEST(prequeue) {
3043 	struct tls_crypto_info_keys tls12;
3044 	char buf[20000], buf2[20000];
3045 	struct sockaddr_in addr;
3046 	int sfd, cfd, ret, fd;
3047 	socklen_t len;
3048 
3049 	len = sizeof(addr);
3050 	memrnd(buf, sizeof(buf));
3051 
3052 	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_256, &tls12, 0);
3053 
3054 	addr.sin_family = AF_INET;
3055 	addr.sin_addr.s_addr = htonl(INADDR_ANY);
3056 	addr.sin_port = 0;
3057 
3058 	fd = socket(AF_INET, SOCK_STREAM, 0);
3059 	sfd = socket(AF_INET, SOCK_STREAM, 0);
3060 
3061 	ASSERT_EQ(bind(sfd, &addr, sizeof(addr)), 0);
3062 	ASSERT_EQ(listen(sfd, 10), 0);
3063 	ASSERT_EQ(getsockname(sfd, &addr, &len), 0);
3064 	ASSERT_EQ(connect(fd, &addr, sizeof(addr)), 0);
3065 	ASSERT_GE(cfd = accept(sfd, &addr, &len), 0);
3066 	close(sfd);
3067 
3068 	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
3069 	if (ret) {
3070 		ASSERT_EQ(errno, ENOENT);
3071 		SKIP(return, "no TLS support");
3072 	}
3073 
3074 	ASSERT_EQ(setsockopt(fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
3075 	EXPECT_EQ(send(fd, buf, sizeof(buf), MSG_DONTWAIT), sizeof(buf));
3076 
3077 	ASSERT_EQ(setsockopt(cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls")), 0);
3078 	ASSERT_EQ(setsockopt(cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
3079 	EXPECT_EQ(recv(cfd, buf2, sizeof(buf2), MSG_WAITALL), sizeof(buf2));
3080 
3081 	EXPECT_EQ(memcmp(buf, buf2, sizeof(buf)), 0);
3082 
3083 	close(fd);
3084 	close(cfd);
3085 }
3086 
TEST(data_steal)3087 TEST(data_steal) {
3088 	struct tls_crypto_info_keys tls;
3089 	char buf[20000], buf2[20000];
3090 	struct sockaddr_in addr;
3091 	int sfd, cfd, ret, fd;
3092 	int pid, status;
3093 	socklen_t len;
3094 
3095 	len = sizeof(addr);
3096 	memrnd(buf, sizeof(buf));
3097 
3098 	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_256, &tls, 0);
3099 
3100 	addr.sin_family = AF_INET;
3101 	addr.sin_addr.s_addr = htonl(INADDR_ANY);
3102 	addr.sin_port = 0;
3103 
3104 	fd = socket(AF_INET, SOCK_STREAM, 0);
3105 	sfd = socket(AF_INET, SOCK_STREAM, 0);
3106 
3107 	ASSERT_EQ(bind(sfd, &addr, sizeof(addr)), 0);
3108 	ASSERT_EQ(listen(sfd, 10), 0);
3109 	ASSERT_EQ(getsockname(sfd, &addr, &len), 0);
3110 	ASSERT_EQ(connect(fd, &addr, sizeof(addr)), 0);
3111 	ASSERT_GE(cfd = accept(sfd, &addr, &len), 0);
3112 	close(sfd);
3113 
3114 	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
3115 	if (ret) {
3116 		ASSERT_EQ(errno, ENOENT);
3117 		SKIP(return, "no TLS support");
3118 	}
3119 	ASSERT_EQ(setsockopt(cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls")), 0);
3120 
3121 	/* Spawn a child and get it into the read wait path of the underlying
3122 	 * TCP socket.
3123 	 */
3124 	pid = fork();
3125 	ASSERT_GE(pid, 0);
3126 	if (!pid) {
3127 		EXPECT_EQ(recv(cfd, buf, sizeof(buf) / 2, MSG_WAITALL),
3128 			  sizeof(buf) / 2);
3129 		exit(!__test_passed(_metadata));
3130 	}
3131 
3132 	usleep(10000);
3133 	ASSERT_EQ(setsockopt(fd, SOL_TLS, TLS_TX, &tls, tls.len), 0);
3134 	ASSERT_EQ(setsockopt(cfd, SOL_TLS, TLS_RX, &tls, tls.len), 0);
3135 
3136 	EXPECT_EQ(send(fd, buf, sizeof(buf), 0), sizeof(buf));
3137 	EXPECT_EQ(wait(&status), pid);
3138 	EXPECT_EQ(status, 0);
3139 	EXPECT_EQ(recv(cfd, buf2, sizeof(buf2), MSG_DONTWAIT), -1);
3140 	/* Don't check errno, the error will be different depending
3141 	 * on what random bytes TLS interpreted as the record length.
3142 	 */
3143 
3144 	close(fd);
3145 	close(cfd);
3146 }
3147 
fips_check(void)3148 static void __attribute__((constructor)) fips_check(void) {
3149 	int res;
3150 	FILE *f;
3151 
3152 	f = fopen("/proc/sys/crypto/fips_enabled", "r");
3153 	if (f) {
3154 		res = fscanf(f, "%d", &fips_enabled);
3155 		if (res != 1)
3156 			ksft_print_msg("ERROR: Couldn't read /proc/sys/crypto/fips_enabled\n");
3157 		fclose(f);
3158 	}
3159 }
3160 
3161 TEST_HARNESS_MAIN
3162