xref: /linux/tools/testing/selftests/net/tls.c (revision 6439a0e64c355d2e375bd094f365d56ce81faba3)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #define _GNU_SOURCE
4 
#include <arpa/inet.h>
#include <errno.h>
#include <error.h>
#include <fcntl.h>
#include <poll.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <linux/tls.h>
#include <linux/tcp.h>
#include <linux/socket.h>

#include <sys/epoll.h>
#include <sys/types.h>
#include <sys/sendfile.h>
#include <sys/socket.h>
#include <sys/stat.h>

#include "../kselftest_harness.h"
25 
/* A TLS record carries at most 2^14 bytes of plaintext. */
#define TLS_PAYLOAD_MAX_LEN (1 << 14)
/* setsockopt()/cmsg level used for TLS_TX, TLS_RX and record-type cmsgs. */
#define SOL_TLS 282

/*
 * Non-zero when the system runs in FIPS mode; FIXTURE_SETUP(tls) then skips
 * variants flagged fips_non_compliant. Assigned outside this excerpt
 * (presumably in main()) — confirm.
 */
static int fips_enabled;
30 
/*
 * Buffer big enough for any cipher's crypto_info, passed to
 * setsockopt(SOL_TLS, TLS_TX/TLS_RX). The per-cipher members each carry the
 * common header as .info; crypto_info gives direct access to that header.
 */
struct tls_crypto_info_keys {
	union {
		struct tls_crypto_info crypto_info;
		struct tls12_crypto_info_aes_gcm_128 aes128;
		struct tls12_crypto_info_chacha20_poly1305 chacha20;
		struct tls12_crypto_info_sm4_gcm sm4gcm;
		struct tls12_crypto_info_sm4_ccm sm4ccm;
		struct tls12_crypto_info_aes_ccm_128 aesccm128;
		struct tls12_crypto_info_aes_gcm_256 aesgcm256;
		struct tls12_crypto_info_aria_gcm_128 ariagcm128;
		struct tls12_crypto_info_aria_gcm_256 ariagcm256;
	};
	/* Size of the cipher-specific struct; used as the setsockopt() optlen. */
	size_t len;
};
45 
tls_crypto_info_init(uint16_t tls_version,uint16_t cipher_type,struct tls_crypto_info_keys * tls12,char key_generation)46 static void tls_crypto_info_init(uint16_t tls_version, uint16_t cipher_type,
47 				 struct tls_crypto_info_keys *tls12,
48 				 char key_generation)
49 {
50 	memset(tls12, key_generation, sizeof(*tls12));
51 	memset(tls12, 0, sizeof(struct tls_crypto_info));
52 
53 	switch (cipher_type) {
54 	case TLS_CIPHER_CHACHA20_POLY1305:
55 		tls12->len = sizeof(struct tls12_crypto_info_chacha20_poly1305);
56 		tls12->chacha20.info.version = tls_version;
57 		tls12->chacha20.info.cipher_type = cipher_type;
58 		break;
59 	case TLS_CIPHER_AES_GCM_128:
60 		tls12->len = sizeof(struct tls12_crypto_info_aes_gcm_128);
61 		tls12->aes128.info.version = tls_version;
62 		tls12->aes128.info.cipher_type = cipher_type;
63 		break;
64 	case TLS_CIPHER_SM4_GCM:
65 		tls12->len = sizeof(struct tls12_crypto_info_sm4_gcm);
66 		tls12->sm4gcm.info.version = tls_version;
67 		tls12->sm4gcm.info.cipher_type = cipher_type;
68 		break;
69 	case TLS_CIPHER_SM4_CCM:
70 		tls12->len = sizeof(struct tls12_crypto_info_sm4_ccm);
71 		tls12->sm4ccm.info.version = tls_version;
72 		tls12->sm4ccm.info.cipher_type = cipher_type;
73 		break;
74 	case TLS_CIPHER_AES_CCM_128:
75 		tls12->len = sizeof(struct tls12_crypto_info_aes_ccm_128);
76 		tls12->aesccm128.info.version = tls_version;
77 		tls12->aesccm128.info.cipher_type = cipher_type;
78 		break;
79 	case TLS_CIPHER_AES_GCM_256:
80 		tls12->len = sizeof(struct tls12_crypto_info_aes_gcm_256);
81 		tls12->aesgcm256.info.version = tls_version;
82 		tls12->aesgcm256.info.cipher_type = cipher_type;
83 		break;
84 	case TLS_CIPHER_ARIA_GCM_128:
85 		tls12->len = sizeof(struct tls12_crypto_info_aria_gcm_128);
86 		tls12->ariagcm128.info.version = tls_version;
87 		tls12->ariagcm128.info.cipher_type = cipher_type;
88 		break;
89 	case TLS_CIPHER_ARIA_GCM_256:
90 		tls12->len = sizeof(struct tls12_crypto_info_aria_gcm_256);
91 		tls12->ariagcm256.info.version = tls_version;
92 		tls12->ariagcm256.info.cipher_type = cipher_type;
93 		break;
94 	default:
95 		break;
96 	}
97 }
98 
/*
 * memrnd() - fill @n bytes at @s with rand() output.
 *
 * Whole ints are produced first and copied with memcpy(), so @s does not
 * need int alignment (callers pass plain char arrays; the previous int *
 * store was a misaligned, aliasing-violating access). Leftover bytes get one
 * rand() value each, truncated to a byte.
 */
static void memrnd(void *s, size_t n)
{
	unsigned char *p = s;
	int r;

	for (; n >= sizeof(r); n -= sizeof(r), p += sizeof(r)) {
		r = rand();
		memcpy(p, &r, sizeof(r));
	}
	while (n--)
		*p++ = (unsigned char)rand();
}
110 
ulp_sock_pair(struct __test_metadata * _metadata,int * fd,int * cfd,bool * notls)111 static void ulp_sock_pair(struct __test_metadata *_metadata,
112 			  int *fd, int *cfd, bool *notls)
113 {
114 	struct sockaddr_in addr;
115 	socklen_t len;
116 	int sfd, ret;
117 
118 	*notls = false;
119 	len = sizeof(addr);
120 
121 	addr.sin_family = AF_INET;
122 	addr.sin_addr.s_addr = htonl(INADDR_ANY);
123 	addr.sin_port = 0;
124 
125 	*fd = socket(AF_INET, SOCK_STREAM, 0);
126 	sfd = socket(AF_INET, SOCK_STREAM, 0);
127 
128 	ret = bind(sfd, &addr, sizeof(addr));
129 	ASSERT_EQ(ret, 0);
130 	ret = listen(sfd, 10);
131 	ASSERT_EQ(ret, 0);
132 
133 	ret = getsockname(sfd, &addr, &len);
134 	ASSERT_EQ(ret, 0);
135 
136 	ret = connect(*fd, &addr, sizeof(addr));
137 	ASSERT_EQ(ret, 0);
138 
139 	*cfd = accept(sfd, &addr, &len);
140 	ASSERT_GE(*cfd, 0);
141 
142 	close(sfd);
143 
144 	ret = setsockopt(*fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
145 	if (ret != 0) {
146 		ASSERT_EQ(errno, ENOENT);
147 		*notls = true;
148 		printf("Failure setting TCP_ULP, testing without tls\n");
149 		return;
150 	}
151 
152 	ret = setsockopt(*cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
153 	ASSERT_EQ(ret, 0);
154 }
155 
156 /* Produce a basic cmsg */
tls_send_cmsg(int fd,unsigned char record_type,void * data,size_t len,int flags)157 static int tls_send_cmsg(int fd, unsigned char record_type,
158 			 void *data, size_t len, int flags)
159 {
160 	char cbuf[CMSG_SPACE(sizeof(char))];
161 	int cmsg_len = sizeof(char);
162 	struct cmsghdr *cmsg;
163 	struct msghdr msg;
164 	struct iovec vec;
165 
166 	vec.iov_base = data;
167 	vec.iov_len = len;
168 	memset(&msg, 0, sizeof(struct msghdr));
169 	msg.msg_iov = &vec;
170 	msg.msg_iovlen = 1;
171 	msg.msg_control = cbuf;
172 	msg.msg_controllen = sizeof(cbuf);
173 	cmsg = CMSG_FIRSTHDR(&msg);
174 	cmsg->cmsg_level = SOL_TLS;
175 	/* test sending non-record types. */
176 	cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
177 	cmsg->cmsg_len = CMSG_LEN(cmsg_len);
178 	*CMSG_DATA(cmsg) = record_type;
179 	msg.msg_controllen = cmsg->cmsg_len;
180 
181 	return sendmsg(fd, &msg, flags);
182 }
183 
__tls_recv_cmsg(struct __test_metadata * _metadata,int fd,unsigned char * ctype,void * data,size_t len,int flags)184 static int __tls_recv_cmsg(struct __test_metadata *_metadata,
185 			   int fd, unsigned char *ctype,
186 			   void *data, size_t len, int flags)
187 {
188 	char cbuf[CMSG_SPACE(sizeof(char))];
189 	struct cmsghdr *cmsg;
190 	struct msghdr msg;
191 	struct iovec vec;
192 	int n;
193 
194 	vec.iov_base = data;
195 	vec.iov_len = len;
196 	memset(&msg, 0, sizeof(struct msghdr));
197 	msg.msg_iov = &vec;
198 	msg.msg_iovlen = 1;
199 	msg.msg_control = cbuf;
200 	msg.msg_controllen = sizeof(cbuf);
201 
202 	n = recvmsg(fd, &msg, flags);
203 
204 	cmsg = CMSG_FIRSTHDR(&msg);
205 	EXPECT_NE(cmsg, NULL);
206 	EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
207 	EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
208 	if (ctype)
209 		*ctype = *((unsigned char *)CMSG_DATA(cmsg));
210 
211 	return n;
212 }
213 
/*
 * tls_recv_cmsg() - receive like __tls_recv_cmsg() and additionally expect
 * the record type to equal @record_type. Returns recvmsg()'s result.
 */
static int tls_recv_cmsg(struct __test_metadata *_metadata,
			 int fd, unsigned char record_type,
			 void *data, size_t len, int flags)
{
	unsigned char ctype;
	int n = __tls_recv_cmsg(_metadata, fd, &ctype, data, len, flags);

	EXPECT_EQ(ctype, record_type);
	return n;
}
226 
/* Socket pair with the "tls" ULP attached but no keys installed. */
FIXTURE(tls_basic)
{
	int fd;		/* connecting side */
	int cfd;	/* accepted side */
	bool notls;	/* no "tls" ULP available: plain TCP only */
};
232 
FIXTURE_SETUP(tls_basic)233 FIXTURE_SETUP(tls_basic)
234 {
235 	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
236 }
237 
FIXTURE_TEARDOWN(tls_basic)238 FIXTURE_TEARDOWN(tls_basic)
239 {
240 	close(self->fd);
241 	close(self->cfd);
242 }
243 
244 /* Send some data through with ULP but no keys */
TEST_F(tls_basic,base_base)245 TEST_F(tls_basic, base_base)
246 {
247 	char const *test_str = "test_read";
248 	int send_len = 10;
249 	char buf[10];
250 
251 	ASSERT_EQ(strlen(test_str) + 1, send_len);
252 
253 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
254 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
255 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
256 };
257 
TEST_F(tls_basic,bad_cipher)258 TEST_F(tls_basic, bad_cipher)
259 {
260 	struct tls_crypto_info_keys tls12;
261 
262 	tls12.crypto_info.version = 200;
263 	tls12.crypto_info.cipher_type = TLS_CIPHER_AES_GCM_128;
264 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
265 
266 	tls12.crypto_info.version = TLS_1_2_VERSION;
267 	tls12.crypto_info.cipher_type = 50;
268 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
269 
270 	tls12.crypto_info.version = TLS_1_2_VERSION;
271 	tls12.crypto_info.cipher_type = 59;
272 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
273 
274 	tls12.crypto_info.version = TLS_1_2_VERSION;
275 	tls12.crypto_info.cipher_type = 10;
276 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
277 
278 	tls12.crypto_info.version = TLS_1_2_VERSION;
279 	tls12.crypto_info.cipher_type = 70;
280 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
281 }
282 
TEST_F(tls_basic,recseq_wrap)283 TEST_F(tls_basic, recseq_wrap)
284 {
285 	struct tls_crypto_info_keys tls12;
286 	char const *test_str = "test_read";
287 	int send_len = 10;
288 
289 	if (self->notls)
290 		SKIP(return, "no TLS support");
291 
292 	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_128, &tls12, 0);
293 	memset(&tls12.aes128.rec_seq, 0xff, sizeof(tls12.aes128.rec_seq));
294 
295 	ASSERT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
296 	ASSERT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
297 
298 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), -1);
299 	EXPECT_EQ(errno, EBADMSG);
300 }
301 
/* Socket pair with TLS TX keys on fd and matching RX keys on cfd. */
FIXTURE(tls)
{
	int fd;		/* connecting side, TLS_TX */
	int cfd;	/* accepted side, TLS_RX */
	bool notls;	/* no "tls" ULP available: plain TCP only */
};
307 
/* Per-variant knobs consumed by FIXTURE_SETUP(tls). */
FIXTURE_VARIANT(tls)
{
	uint16_t tls_version;		/* TLS_1_2_VERSION / TLS_1_3_VERSION */
	uint16_t cipher_type;		/* TLS_CIPHER_* */
	bool nopad;			/* also set TLS_RX_EXPECT_NO_PAD on cfd */
	bool fips_non_compliant;	/* skip when fips_enabled */
};
314 
315 FIXTURE_VARIANT_ADD(tls, 12_aes_gcm)
316 {
317 	.tls_version = TLS_1_2_VERSION,
318 	.cipher_type = TLS_CIPHER_AES_GCM_128,
319 };
320 
321 FIXTURE_VARIANT_ADD(tls, 13_aes_gcm)
322 {
323 	.tls_version = TLS_1_3_VERSION,
324 	.cipher_type = TLS_CIPHER_AES_GCM_128,
325 };
326 
327 FIXTURE_VARIANT_ADD(tls, 12_chacha)
328 {
329 	.tls_version = TLS_1_2_VERSION,
330 	.cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
331 	.fips_non_compliant = true,
332 };
333 
334 FIXTURE_VARIANT_ADD(tls, 13_chacha)
335 {
336 	.tls_version = TLS_1_3_VERSION,
337 	.cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
338 	.fips_non_compliant = true,
339 };
340 
341 FIXTURE_VARIANT_ADD(tls, 13_sm4_gcm)
342 {
343 	.tls_version = TLS_1_3_VERSION,
344 	.cipher_type = TLS_CIPHER_SM4_GCM,
345 	.fips_non_compliant = true,
346 };
347 
348 FIXTURE_VARIANT_ADD(tls, 13_sm4_ccm)
349 {
350 	.tls_version = TLS_1_3_VERSION,
351 	.cipher_type = TLS_CIPHER_SM4_CCM,
352 	.fips_non_compliant = true,
353 };
354 
355 FIXTURE_VARIANT_ADD(tls, 12_aes_ccm)
356 {
357 	.tls_version = TLS_1_2_VERSION,
358 	.cipher_type = TLS_CIPHER_AES_CCM_128,
359 };
360 
361 FIXTURE_VARIANT_ADD(tls, 13_aes_ccm)
362 {
363 	.tls_version = TLS_1_3_VERSION,
364 	.cipher_type = TLS_CIPHER_AES_CCM_128,
365 };
366 
367 FIXTURE_VARIANT_ADD(tls, 12_aes_gcm_256)
368 {
369 	.tls_version = TLS_1_2_VERSION,
370 	.cipher_type = TLS_CIPHER_AES_GCM_256,
371 };
372 
373 FIXTURE_VARIANT_ADD(tls, 13_aes_gcm_256)
374 {
375 	.tls_version = TLS_1_3_VERSION,
376 	.cipher_type = TLS_CIPHER_AES_GCM_256,
377 };
378 
379 FIXTURE_VARIANT_ADD(tls, 13_nopad)
380 {
381 	.tls_version = TLS_1_3_VERSION,
382 	.cipher_type = TLS_CIPHER_AES_GCM_128,
383 	.nopad = true,
384 };
385 
386 FIXTURE_VARIANT_ADD(tls, 12_aria_gcm)
387 {
388 	.tls_version = TLS_1_2_VERSION,
389 	.cipher_type = TLS_CIPHER_ARIA_GCM_128,
390 };
391 
392 FIXTURE_VARIANT_ADD(tls, 12_aria_gcm_256)
393 {
394 	.tls_version = TLS_1_2_VERSION,
395 	.cipher_type = TLS_CIPHER_ARIA_GCM_256,
396 };
397 
FIXTURE_SETUP(tls)398 FIXTURE_SETUP(tls)
399 {
400 	struct tls_crypto_info_keys tls12;
401 	int one = 1;
402 	int ret;
403 
404 	if (fips_enabled && variant->fips_non_compliant)
405 		SKIP(return, "Unsupported cipher in FIPS mode");
406 
407 	tls_crypto_info_init(variant->tls_version, variant->cipher_type,
408 			     &tls12, 0);
409 
410 	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
411 
412 	if (self->notls)
413 		return;
414 
415 	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
416 	ASSERT_EQ(ret, 0);
417 
418 	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len);
419 	ASSERT_EQ(ret, 0);
420 
421 	if (variant->nopad) {
422 		ret = setsockopt(self->cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
423 				 (void *)&one, sizeof(one));
424 		ASSERT_EQ(ret, 0);
425 	}
426 }
427 
FIXTURE_TEARDOWN(tls)428 FIXTURE_TEARDOWN(tls)
429 {
430 	close(self->fd);
431 	close(self->cfd);
432 }
433 
TEST_F(tls,sendfile)434 TEST_F(tls, sendfile)
435 {
436 	int filefd = open("/proc/self/exe", O_RDONLY);
437 	struct stat st;
438 
439 	EXPECT_GE(filefd, 0);
440 	fstat(filefd, &st);
441 	EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
442 }
443 
TEST_F(tls,send_then_sendfile)444 TEST_F(tls, send_then_sendfile)
445 {
446 	int filefd = open("/proc/self/exe", O_RDONLY);
447 	char const *test_str = "test_send";
448 	int to_send = strlen(test_str) + 1;
449 	char recv_buf[10];
450 	struct stat st;
451 	char *buf;
452 
453 	EXPECT_GE(filefd, 0);
454 	fstat(filefd, &st);
455 	buf = (char *)malloc(st.st_size);
456 
457 	EXPECT_EQ(send(self->fd, test_str, to_send, 0), to_send);
458 	EXPECT_EQ(recv(self->cfd, recv_buf, to_send, MSG_WAITALL), to_send);
459 	EXPECT_EQ(memcmp(test_str, recv_buf, to_send), 0);
460 
461 	EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
462 	EXPECT_EQ(recv(self->cfd, buf, st.st_size, MSG_WAITALL), st.st_size);
463 }
464 
chunked_sendfile(struct __test_metadata * _metadata,struct _test_data_tls * self,uint16_t chunk_size,uint16_t extra_payload_size)465 static void chunked_sendfile(struct __test_metadata *_metadata,
466 			     struct _test_data_tls *self,
467 			     uint16_t chunk_size,
468 			     uint16_t extra_payload_size)
469 {
470 	char buf[TLS_PAYLOAD_MAX_LEN];
471 	uint16_t test_payload_size;
472 	int size = 0;
473 	int ret;
474 	char filename[] = "/tmp/mytemp.XXXXXX";
475 	int fd = mkstemp(filename);
476 	off_t offset = 0;
477 
478 	unlink(filename);
479 	ASSERT_GE(fd, 0);
480 	EXPECT_GE(chunk_size, 1);
481 	test_payload_size = chunk_size + extra_payload_size;
482 	ASSERT_GE(TLS_PAYLOAD_MAX_LEN, test_payload_size);
483 	memset(buf, 1, test_payload_size);
484 	size = write(fd, buf, test_payload_size);
485 	EXPECT_EQ(size, test_payload_size);
486 	fsync(fd);
487 
488 	while (size > 0) {
489 		ret = sendfile(self->fd, fd, &offset, chunk_size);
490 		EXPECT_GE(ret, 0);
491 		size -= ret;
492 	}
493 
494 	EXPECT_EQ(recv(self->cfd, buf, test_payload_size, MSG_WAITALL),
495 		  test_payload_size);
496 
497 	close(fd);
498 }
499 
TEST_F(tls,multi_chunk_sendfile)500 TEST_F(tls, multi_chunk_sendfile)
501 {
502 	chunked_sendfile(_metadata, self, 4096, 4096);
503 	chunked_sendfile(_metadata, self, 4096, 0);
504 	chunked_sendfile(_metadata, self, 4096, 1);
505 	chunked_sendfile(_metadata, self, 4096, 2048);
506 	chunked_sendfile(_metadata, self, 8192, 2048);
507 	chunked_sendfile(_metadata, self, 4096, 8192);
508 	chunked_sendfile(_metadata, self, 8192, 4096);
509 	chunked_sendfile(_metadata, self, 12288, 1024);
510 	chunked_sendfile(_metadata, self, 12288, 2000);
511 	chunked_sendfile(_metadata, self, 15360, 100);
512 	chunked_sendfile(_metadata, self, 15360, 300);
513 	chunked_sendfile(_metadata, self, 1, 4096);
514 	chunked_sendfile(_metadata, self, 2048, 4096);
515 	chunked_sendfile(_metadata, self, 2048, 8192);
516 	chunked_sendfile(_metadata, self, 4096, 8192);
517 	chunked_sendfile(_metadata, self, 1024, 12288);
518 	chunked_sendfile(_metadata, self, 2000, 12288);
519 	chunked_sendfile(_metadata, self, 100, 15360);
520 	chunked_sendfile(_metadata, self, 300, 15360);
521 }
522 
TEST_F(tls,recv_max)523 TEST_F(tls, recv_max)
524 {
525 	unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
526 	char recv_mem[TLS_PAYLOAD_MAX_LEN];
527 	char buf[TLS_PAYLOAD_MAX_LEN];
528 
529 	memrnd(buf, sizeof(buf));
530 
531 	EXPECT_GE(send(self->fd, buf, send_len, 0), 0);
532 	EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
533 	EXPECT_EQ(memcmp(buf, recv_mem, send_len), 0);
534 }
535 
TEST_F(tls,recv_small)536 TEST_F(tls, recv_small)
537 {
538 	char const *test_str = "test_read";
539 	int send_len = 10;
540 	char buf[10];
541 
542 	send_len = strlen(test_str) + 1;
543 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
544 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
545 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
546 }
547 
TEST_F(tls,msg_more)548 TEST_F(tls, msg_more)
549 {
550 	char const *test_str = "test_read";
551 	int send_len = 10;
552 	char buf[10 * 2];
553 
554 	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
555 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
556 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
557 	EXPECT_EQ(recv(self->cfd, buf, send_len * 2, MSG_WAITALL),
558 		  send_len * 2);
559 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
560 }
561 
TEST_F(tls,msg_more_unsent)562 TEST_F(tls, msg_more_unsent)
563 {
564 	char const *test_str = "test_read";
565 	int send_len = 10;
566 	char buf[10];
567 
568 	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
569 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
570 }
571 
TEST_F(tls,msg_eor)572 TEST_F(tls, msg_eor)
573 {
574 	char const *test_str = "test_read";
575 	int send_len = 10;
576 	char buf[10];
577 
578 	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_EOR), send_len);
579 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
580 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
581 }
582 
TEST_F(tls,sendmsg_single)583 TEST_F(tls, sendmsg_single)
584 {
585 	struct msghdr msg;
586 
587 	char const *test_str = "test_sendmsg";
588 	size_t send_len = 13;
589 	struct iovec vec;
590 	char buf[13];
591 
592 	vec.iov_base = (char *)test_str;
593 	vec.iov_len = send_len;
594 	memset(&msg, 0, sizeof(struct msghdr));
595 	msg.msg_iov = &vec;
596 	msg.msg_iovlen = 1;
597 	EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
598 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
599 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
600 }
601 
602 #define MAX_FRAGS	64
603 #define SEND_LEN	13
TEST_F(tls,sendmsg_fragmented)604 TEST_F(tls, sendmsg_fragmented)
605 {
606 	char const *test_str = "test_sendmsg";
607 	char buf[SEND_LEN * MAX_FRAGS];
608 	struct iovec vec[MAX_FRAGS];
609 	struct msghdr msg;
610 	int i, frags;
611 
612 	for (frags = 1; frags <= MAX_FRAGS; frags++) {
613 		for (i = 0; i < frags; i++) {
614 			vec[i].iov_base = (char *)test_str;
615 			vec[i].iov_len = SEND_LEN;
616 		}
617 
618 		memset(&msg, 0, sizeof(struct msghdr));
619 		msg.msg_iov = vec;
620 		msg.msg_iovlen = frags;
621 
622 		EXPECT_EQ(sendmsg(self->fd, &msg, 0), SEND_LEN * frags);
623 		EXPECT_EQ(recv(self->cfd, buf, SEND_LEN * frags, MSG_WAITALL),
624 			  SEND_LEN * frags);
625 
626 		for (i = 0; i < frags; i++)
627 			EXPECT_EQ(memcmp(buf + SEND_LEN * i,
628 					 test_str, SEND_LEN), 0);
629 	}
630 }
631 #undef MAX_FRAGS
632 #undef SEND_LEN
633 
TEST_F(tls,sendmsg_large)634 TEST_F(tls, sendmsg_large)
635 {
636 	void *mem = malloc(16384);
637 	size_t send_len = 16384;
638 	size_t sends = 128;
639 	struct msghdr msg;
640 	size_t recvs = 0;
641 	size_t sent = 0;
642 
643 	memset(&msg, 0, sizeof(struct msghdr));
644 	while (sent++ < sends) {
645 		struct iovec vec = { (void *)mem, send_len };
646 
647 		msg.msg_iov = &vec;
648 		msg.msg_iovlen = 1;
649 		EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
650 	}
651 
652 	while (recvs++ < sends) {
653 		EXPECT_NE(recv(self->cfd, mem, send_len, 0), -1);
654 	}
655 
656 	free(mem);
657 }
658 
TEST_F(tls,sendmsg_multiple)659 TEST_F(tls, sendmsg_multiple)
660 {
661 	char const *test_str = "test_sendmsg_multiple";
662 	struct iovec vec[5];
663 	char *test_strs[5];
664 	struct msghdr msg;
665 	int total_len = 0;
666 	int len_cmp = 0;
667 	int iov_len = 5;
668 	char *buf;
669 	int i;
670 
671 	memset(&msg, 0, sizeof(struct msghdr));
672 	for (i = 0; i < iov_len; i++) {
673 		test_strs[i] = (char *)malloc(strlen(test_str) + 1);
674 		snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
675 		vec[i].iov_base = (void *)test_strs[i];
676 		vec[i].iov_len = strlen(test_strs[i]) + 1;
677 		total_len += vec[i].iov_len;
678 	}
679 	msg.msg_iov = vec;
680 	msg.msg_iovlen = iov_len;
681 
682 	EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
683 	buf = malloc(total_len);
684 	EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
685 	for (i = 0; i < iov_len; i++) {
686 		EXPECT_EQ(memcmp(test_strs[i], buf + len_cmp,
687 				 strlen(test_strs[i])),
688 			  0);
689 		len_cmp += strlen(buf + len_cmp) + 1;
690 	}
691 	for (i = 0; i < iov_len; i++)
692 		free(test_strs[i]);
693 	free(buf);
694 }
695 
TEST_F(tls,sendmsg_multiple_stress)696 TEST_F(tls, sendmsg_multiple_stress)
697 {
698 	char const *test_str = "abcdefghijklmno";
699 	struct iovec vec[1024];
700 	char *test_strs[1024];
701 	int iov_len = 1024;
702 	int total_len = 0;
703 	char buf[1 << 14];
704 	struct msghdr msg;
705 	int len_cmp = 0;
706 	int i;
707 
708 	memset(&msg, 0, sizeof(struct msghdr));
709 	for (i = 0; i < iov_len; i++) {
710 		test_strs[i] = (char *)malloc(strlen(test_str) + 1);
711 		snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
712 		vec[i].iov_base = (void *)test_strs[i];
713 		vec[i].iov_len = strlen(test_strs[i]) + 1;
714 		total_len += vec[i].iov_len;
715 	}
716 	msg.msg_iov = vec;
717 	msg.msg_iovlen = iov_len;
718 
719 	EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
720 	EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
721 
722 	for (i = 0; i < iov_len; i++)
723 		len_cmp += strlen(buf + len_cmp) + 1;
724 
725 	for (i = 0; i < iov_len; i++)
726 		free(test_strs[i]);
727 }
728 
TEST_F(tls,splice_from_pipe)729 TEST_F(tls, splice_from_pipe)
730 {
731 	int send_len = TLS_PAYLOAD_MAX_LEN;
732 	char mem_send[TLS_PAYLOAD_MAX_LEN];
733 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
734 	int p[2];
735 
736 	ASSERT_GE(pipe(p), 0);
737 	EXPECT_GE(write(p[1], mem_send, send_len), 0);
738 	EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), 0);
739 	EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
740 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
741 }
742 
TEST_F(tls,splice_more)743 TEST_F(tls, splice_more)
744 {
745 	unsigned int f = SPLICE_F_NONBLOCK | SPLICE_F_MORE | SPLICE_F_GIFT;
746 	int send_len = TLS_PAYLOAD_MAX_LEN;
747 	char mem_send[TLS_PAYLOAD_MAX_LEN];
748 	int i, send_pipe = 1;
749 	int p[2];
750 
751 	ASSERT_GE(pipe(p), 0);
752 	EXPECT_GE(write(p[1], mem_send, send_len), 0);
753 	for (i = 0; i < 32; i++)
754 		EXPECT_EQ(splice(p[0], NULL, self->fd, NULL, send_pipe, f), 1);
755 }
756 
TEST_F(tls,splice_from_pipe2)757 TEST_F(tls, splice_from_pipe2)
758 {
759 	int send_len = 16000;
760 	char mem_send[16000];
761 	char mem_recv[16000];
762 	int p2[2];
763 	int p[2];
764 
765 	memrnd(mem_send, sizeof(mem_send));
766 
767 	ASSERT_GE(pipe(p), 0);
768 	ASSERT_GE(pipe(p2), 0);
769 	EXPECT_EQ(write(p[1], mem_send, 8000), 8000);
770 	EXPECT_EQ(splice(p[0], NULL, self->fd, NULL, 8000, 0), 8000);
771 	EXPECT_EQ(write(p2[1], mem_send + 8000, 8000), 8000);
772 	EXPECT_EQ(splice(p2[0], NULL, self->fd, NULL, 8000, 0), 8000);
773 	EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
774 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
775 }
776 
TEST_F(tls,send_and_splice)777 TEST_F(tls, send_and_splice)
778 {
779 	int send_len = TLS_PAYLOAD_MAX_LEN;
780 	char mem_send[TLS_PAYLOAD_MAX_LEN];
781 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
782 	char const *test_str = "test_read";
783 	int send_len2 = 10;
784 	char buf[10];
785 	int p[2];
786 
787 	ASSERT_GE(pipe(p), 0);
788 	EXPECT_EQ(send(self->fd, test_str, send_len2, 0), send_len2);
789 	EXPECT_EQ(recv(self->cfd, buf, send_len2, MSG_WAITALL), send_len2);
790 	EXPECT_EQ(memcmp(test_str, buf, send_len2), 0);
791 
792 	EXPECT_GE(write(p[1], mem_send, send_len), send_len);
793 	EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), send_len);
794 
795 	EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
796 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
797 }
798 
TEST_F(tls,splice_to_pipe)799 TEST_F(tls, splice_to_pipe)
800 {
801 	int send_len = TLS_PAYLOAD_MAX_LEN;
802 	char mem_send[TLS_PAYLOAD_MAX_LEN];
803 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
804 	int p[2];
805 
806 	memrnd(mem_send, sizeof(mem_send));
807 
808 	ASSERT_GE(pipe(p), 0);
809 	EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
810 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), send_len);
811 	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
812 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
813 }
814 
TEST_F(tls,splice_cmsg_to_pipe)815 TEST_F(tls, splice_cmsg_to_pipe)
816 {
817 	char *test_str = "test_read";
818 	char record_type = 100;
819 	int send_len = 10;
820 	char buf[10];
821 	int p[2];
822 
823 	if (self->notls)
824 		SKIP(return, "no TLS support");
825 
826 	ASSERT_GE(pipe(p), 0);
827 	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
828 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
829 	EXPECT_EQ(errno, EINVAL);
830 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
831 	EXPECT_EQ(errno, EIO);
832 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
833 				buf, sizeof(buf), MSG_WAITALL),
834 		  send_len);
835 	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
836 }
837 
TEST_F(tls,splice_dec_cmsg_to_pipe)838 TEST_F(tls, splice_dec_cmsg_to_pipe)
839 {
840 	char *test_str = "test_read";
841 	char record_type = 100;
842 	int send_len = 10;
843 	char buf[10];
844 	int p[2];
845 
846 	if (self->notls)
847 		SKIP(return, "no TLS support");
848 
849 	ASSERT_GE(pipe(p), 0);
850 	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
851 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
852 	EXPECT_EQ(errno, EIO);
853 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
854 	EXPECT_EQ(errno, EINVAL);
855 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
856 				buf, sizeof(buf), MSG_WAITALL),
857 		  send_len);
858 	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
859 }
860 
TEST_F(tls,recv_and_splice)861 TEST_F(tls, recv_and_splice)
862 {
863 	int send_len = TLS_PAYLOAD_MAX_LEN;
864 	char mem_send[TLS_PAYLOAD_MAX_LEN];
865 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
866 	int half = send_len / 2;
867 	int p[2];
868 
869 	ASSERT_GE(pipe(p), 0);
870 	EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
871 	/* Recv hald of the record, splice the other half */
872 	EXPECT_EQ(recv(self->cfd, mem_recv, half, MSG_WAITALL), half);
873 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, half, SPLICE_F_NONBLOCK),
874 		  half);
875 	EXPECT_EQ(read(p[0], &mem_recv[half], half), half);
876 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
877 }
878 
TEST_F(tls,peek_and_splice)879 TEST_F(tls, peek_and_splice)
880 {
881 	int send_len = TLS_PAYLOAD_MAX_LEN;
882 	char mem_send[TLS_PAYLOAD_MAX_LEN];
883 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
884 	int chunk = TLS_PAYLOAD_MAX_LEN / 4;
885 	int n, i, p[2];
886 
887 	memrnd(mem_send, sizeof(mem_send));
888 
889 	ASSERT_GE(pipe(p), 0);
890 	for (i = 0; i < 4; i++)
891 		EXPECT_EQ(send(self->fd, &mem_send[chunk * i], chunk, 0),
892 			  chunk);
893 
894 	EXPECT_EQ(recv(self->cfd, mem_recv, chunk * 5 / 2,
895 		       MSG_WAITALL | MSG_PEEK),
896 		  chunk * 5 / 2);
897 	EXPECT_EQ(memcmp(mem_send, mem_recv, chunk * 5 / 2), 0);
898 
899 	n = 0;
900 	while (n < send_len) {
901 		i = splice(self->cfd, NULL, p[1], NULL, send_len - n, 0);
902 		EXPECT_GT(i, 0);
903 		n += i;
904 	}
905 	EXPECT_EQ(n, send_len);
906 	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
907 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
908 }
909 
TEST_F(tls,recvmsg_single)910 TEST_F(tls, recvmsg_single)
911 {
912 	char const *test_str = "test_recvmsg_single";
913 	int send_len = strlen(test_str) + 1;
914 	char buf[20];
915 	struct msghdr hdr;
916 	struct iovec vec;
917 
918 	memset(&hdr, 0, sizeof(hdr));
919 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
920 	vec.iov_base = (char *)buf;
921 	vec.iov_len = send_len;
922 	hdr.msg_iovlen = 1;
923 	hdr.msg_iov = &vec;
924 	EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
925 	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
926 }
927 
TEST_F(tls,recvmsg_single_max)928 TEST_F(tls, recvmsg_single_max)
929 {
930 	int send_len = TLS_PAYLOAD_MAX_LEN;
931 	char send_mem[TLS_PAYLOAD_MAX_LEN];
932 	char recv_mem[TLS_PAYLOAD_MAX_LEN];
933 	struct iovec vec;
934 	struct msghdr hdr;
935 
936 	memrnd(send_mem, sizeof(send_mem));
937 
938 	EXPECT_EQ(send(self->fd, send_mem, send_len, 0), send_len);
939 	vec.iov_base = (char *)recv_mem;
940 	vec.iov_len = TLS_PAYLOAD_MAX_LEN;
941 
942 	hdr.msg_iovlen = 1;
943 	hdr.msg_iov = &vec;
944 	EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
945 	EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
946 }
947 
TEST_F(tls,recvmsg_multiple)948 TEST_F(tls, recvmsg_multiple)
949 {
950 	unsigned int msg_iovlen = 1024;
951 	struct iovec vec[1024];
952 	char *iov_base[1024];
953 	unsigned int iov_len = 16;
954 	int send_len = 1 << 14;
955 	char buf[1 << 14];
956 	struct msghdr hdr;
957 	int i;
958 
959 	memrnd(buf, sizeof(buf));
960 
961 	EXPECT_EQ(send(self->fd, buf, send_len, 0), send_len);
962 	for (i = 0; i < msg_iovlen; i++) {
963 		iov_base[i] = (char *)malloc(iov_len);
964 		vec[i].iov_base = iov_base[i];
965 		vec[i].iov_len = iov_len;
966 	}
967 
968 	hdr.msg_iovlen = msg_iovlen;
969 	hdr.msg_iov = vec;
970 	EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
971 
972 	for (i = 0; i < msg_iovlen; i++)
973 		free(iov_base[i]);
974 }
975 
TEST_F(tls,single_send_multiple_recv)976 TEST_F(tls, single_send_multiple_recv)
977 {
978 	unsigned int total_len = TLS_PAYLOAD_MAX_LEN * 2;
979 	unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
980 	char send_mem[TLS_PAYLOAD_MAX_LEN * 2];
981 	char recv_mem[TLS_PAYLOAD_MAX_LEN * 2];
982 
983 	memrnd(send_mem, sizeof(send_mem));
984 
985 	EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
986 	memset(recv_mem, 0, total_len);
987 
988 	EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
989 	EXPECT_NE(recv(self->cfd, recv_mem + send_len, send_len, 0), -1);
990 	EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
991 }
992 
TEST_F(tls,multiple_send_single_recv)993 TEST_F(tls, multiple_send_single_recv)
994 {
995 	unsigned int total_len = 2 * 10;
996 	unsigned int send_len = 10;
997 	char recv_mem[2 * 10];
998 	char send_mem[10];
999 
1000 	memrnd(send_mem, sizeof(send_mem));
1001 
1002 	EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
1003 	EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
1004 	memset(recv_mem, 0, total_len);
1005 	EXPECT_EQ(recv(self->cfd, recv_mem, total_len, MSG_WAITALL), total_len);
1006 
1007 	EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
1008 	EXPECT_EQ(memcmp(send_mem, recv_mem + send_len, send_len), 0);
1009 }
1010 
TEST_F(tls,single_send_multiple_recv_non_align)1011 TEST_F(tls, single_send_multiple_recv_non_align)
1012 {
1013 	const unsigned int total_len = 15;
1014 	const unsigned int recv_len = 10;
1015 	char recv_mem[recv_len * 2];
1016 	char send_mem[total_len];
1017 
1018 	memrnd(send_mem, sizeof(send_mem));
1019 
1020 	EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
1021 	memset(recv_mem, 0, total_len);
1022 
1023 	EXPECT_EQ(recv(self->cfd, recv_mem, recv_len, 0), recv_len);
1024 	EXPECT_EQ(recv(self->cfd, recv_mem + recv_len, recv_len, 0), 5);
1025 	EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
1026 }
1027 
TEST_F(tls,recv_partial)1028 TEST_F(tls, recv_partial)
1029 {
1030 	char const *test_str = "test_read_partial";
1031 	char const *test_str_first = "test_read";
1032 	char const *test_str_second = "_partial";
1033 	int send_len = strlen(test_str) + 1;
1034 	char recv_mem[18];
1035 
1036 	memset(recv_mem, 0, sizeof(recv_mem));
1037 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1038 	EXPECT_EQ(recv(self->cfd, recv_mem, strlen(test_str_first),
1039 		       MSG_WAITALL), strlen(test_str_first));
1040 	EXPECT_EQ(memcmp(test_str_first, recv_mem, strlen(test_str_first)), 0);
1041 	memset(recv_mem, 0, sizeof(recv_mem));
1042 	EXPECT_EQ(recv(self->cfd, recv_mem, strlen(test_str_second),
1043 		       MSG_WAITALL), strlen(test_str_second));
1044 	EXPECT_EQ(memcmp(test_str_second, recv_mem, strlen(test_str_second)),
1045 		  0);
1046 }
1047 
TEST_F(tls,recv_nonblock)1048 TEST_F(tls, recv_nonblock)
1049 {
1050 	char buf[4096];
1051 	bool err;
1052 
1053 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
1054 	err = (errno == EAGAIN || errno == EWOULDBLOCK);
1055 	EXPECT_EQ(err, true);
1056 }
1057 
TEST_F(tls,recv_peek)1058 TEST_F(tls, recv_peek)
1059 {
1060 	char const *test_str = "test_read_peek";
1061 	int send_len = strlen(test_str) + 1;
1062 	char buf[15];
1063 
1064 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1065 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_PEEK), send_len);
1066 	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1067 	memset(buf, 0, sizeof(buf));
1068 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1069 	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1070 }
1071 
TEST_F(tls,recv_peek_multiple)1072 TEST_F(tls, recv_peek_multiple)
1073 {
1074 	char const *test_str = "test_read_peek";
1075 	int send_len = strlen(test_str) + 1;
1076 	unsigned int num_peeks = 100;
1077 	char buf[15];
1078 	int i;
1079 
1080 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1081 	for (i = 0; i < num_peeks; i++) {
1082 		EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
1083 		EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1084 		memset(buf, 0, sizeof(buf));
1085 	}
1086 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1087 	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1088 }
1089 
TEST_F(tls,recv_peek_multiple_records)1090 TEST_F(tls, recv_peek_multiple_records)
1091 {
1092 	char const *test_str = "test_read_peek_mult_recs";
1093 	char const *test_str_first = "test_read_peek";
1094 	char const *test_str_second = "_mult_recs";
1095 	int len;
1096 	char buf[64];
1097 
1098 	len = strlen(test_str_first);
1099 	EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
1100 
1101 	len = strlen(test_str_second) + 1;
1102 	EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1103 
1104 	len = strlen(test_str_first);
1105 	memset(buf, 0, len);
1106 	EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
1107 
1108 	/* MSG_PEEK can only peek into the current record. */
1109 	len = strlen(test_str_first);
1110 	EXPECT_EQ(memcmp(test_str_first, buf, len), 0);
1111 
1112 	len = strlen(test_str) + 1;
1113 	memset(buf, 0, len);
1114 	EXPECT_EQ(recv(self->cfd, buf, len, MSG_WAITALL), len);
1115 
1116 	/* Non-MSG_PEEK will advance strparser (and therefore record)
1117 	 * however.
1118 	 */
1119 	len = strlen(test_str) + 1;
1120 	EXPECT_EQ(memcmp(test_str, buf, len), 0);
1121 
1122 	/* MSG_MORE will hold current record open, so later MSG_PEEK
1123 	 * will see everything.
1124 	 */
1125 	len = strlen(test_str_first);
1126 	EXPECT_EQ(send(self->fd, test_str_first, len, MSG_MORE), len);
1127 
1128 	len = strlen(test_str_second) + 1;
1129 	EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1130 
1131 	len = strlen(test_str) + 1;
1132 	memset(buf, 0, len);
1133 	EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
1134 
1135 	len = strlen(test_str) + 1;
1136 	EXPECT_EQ(memcmp(test_str, buf, len), 0);
1137 }
1138 
TEST_F(tls,recv_peek_large_buf_mult_recs)1139 TEST_F(tls, recv_peek_large_buf_mult_recs)
1140 {
1141 	char const *test_str = "test_read_peek_mult_recs";
1142 	char const *test_str_first = "test_read_peek";
1143 	char const *test_str_second = "_mult_recs";
1144 	int len;
1145 	char buf[64];
1146 
1147 	len = strlen(test_str_first);
1148 	EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
1149 
1150 	len = strlen(test_str_second) + 1;
1151 	EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1152 
1153 	len = strlen(test_str) + 1;
1154 	memset(buf, 0, len);
1155 	EXPECT_NE((len = recv(self->cfd, buf, len,
1156 			      MSG_PEEK | MSG_WAITALL)), -1);
1157 	len = strlen(test_str) + 1;
1158 	EXPECT_EQ(memcmp(test_str, buf, len), 0);
1159 }
1160 
TEST_F(tls,recv_lowat)1161 TEST_F(tls, recv_lowat)
1162 {
1163 	char send_mem[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
1164 	char recv_mem[20];
1165 	int lowat = 8;
1166 
1167 	EXPECT_EQ(send(self->fd, send_mem, 10, 0), 10);
1168 	EXPECT_EQ(send(self->fd, send_mem, 5, 0), 5);
1169 
1170 	memset(recv_mem, 0, 20);
1171 	EXPECT_EQ(setsockopt(self->cfd, SOL_SOCKET, SO_RCVLOWAT,
1172 			     &lowat, sizeof(lowat)), 0);
1173 	EXPECT_EQ(recv(self->cfd, recv_mem, 1, MSG_WAITALL), 1);
1174 	EXPECT_EQ(recv(self->cfd, recv_mem + 1, 6, MSG_WAITALL), 6);
1175 	EXPECT_EQ(recv(self->cfd, recv_mem + 7, 10, 0), 8);
1176 
1177 	EXPECT_EQ(memcmp(send_mem, recv_mem, 10), 0);
1178 	EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0);
1179 }
1180 
TEST_F(tls,bidir)1181 TEST_F(tls, bidir)
1182 {
1183 	char const *test_str = "test_read";
1184 	int send_len = 10;
1185 	char buf[10];
1186 	int ret;
1187 
1188 	if (!self->notls) {
1189 		struct tls_crypto_info_keys tls12;
1190 
1191 		tls_crypto_info_init(variant->tls_version, variant->cipher_type,
1192 				     &tls12, 0);
1193 
1194 		ret = setsockopt(self->fd, SOL_TLS, TLS_RX, &tls12,
1195 				 tls12.len);
1196 		ASSERT_EQ(ret, 0);
1197 
1198 		ret = setsockopt(self->cfd, SOL_TLS, TLS_TX, &tls12,
1199 				 tls12.len);
1200 		ASSERT_EQ(ret, 0);
1201 	}
1202 
1203 	ASSERT_EQ(strlen(test_str) + 1, send_len);
1204 
1205 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1206 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1207 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1208 
1209 	memset(buf, 0, sizeof(buf));
1210 
1211 	EXPECT_EQ(send(self->cfd, test_str, send_len, 0), send_len);
1212 	EXPECT_NE(recv(self->fd, buf, send_len, 0), -1);
1213 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1214 };
1215 
TEST_F(tls,pollin)1216 TEST_F(tls, pollin)
1217 {
1218 	char const *test_str = "test_poll";
1219 	struct pollfd fd = { 0, 0, 0 };
1220 	char buf[10];
1221 	int send_len = 10;
1222 
1223 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1224 	fd.fd = self->cfd;
1225 	fd.events = POLLIN;
1226 
1227 	EXPECT_EQ(poll(&fd, 1, 20), 1);
1228 	EXPECT_EQ(fd.revents & POLLIN, 1);
1229 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
1230 	/* Test timing out */
1231 	EXPECT_EQ(poll(&fd, 1, 20), 0);
1232 }
1233 
TEST_F(tls,poll_wait)1234 TEST_F(tls, poll_wait)
1235 {
1236 	char const *test_str = "test_poll_wait";
1237 	int send_len = strlen(test_str) + 1;
1238 	struct pollfd fd = { 0, 0, 0 };
1239 	char recv_mem[15];
1240 
1241 	fd.fd = self->cfd;
1242 	fd.events = POLLIN;
1243 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1244 	/* Set timeout to inf. secs */
1245 	EXPECT_EQ(poll(&fd, 1, -1), 1);
1246 	EXPECT_EQ(fd.revents & POLLIN, 1);
1247 	EXPECT_EQ(recv(self->cfd, recv_mem, send_len, MSG_WAITALL), send_len);
1248 }
1249 
TEST_F(tls,poll_wait_split)1250 TEST_F(tls, poll_wait_split)
1251 {
1252 	struct pollfd fd = { 0, 0, 0 };
1253 	char send_mem[20] = {};
1254 	char recv_mem[15];
1255 
1256 	fd.fd = self->cfd;
1257 	fd.events = POLLIN;
1258 	/* Send 20 bytes */
1259 	EXPECT_EQ(send(self->fd, send_mem, sizeof(send_mem), 0),
1260 		  sizeof(send_mem));
1261 	/* Poll with inf. timeout */
1262 	EXPECT_EQ(poll(&fd, 1, -1), 1);
1263 	EXPECT_EQ(fd.revents & POLLIN, 1);
1264 	EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), MSG_WAITALL),
1265 		  sizeof(recv_mem));
1266 
1267 	/* Now the remaining 5 bytes of record data are in TLS ULP */
1268 	fd.fd = self->cfd;
1269 	fd.events = POLLIN;
1270 	EXPECT_EQ(poll(&fd, 1, -1), 1);
1271 	EXPECT_EQ(fd.revents & POLLIN, 1);
1272 	EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0),
1273 		  sizeof(send_mem) - sizeof(recv_mem));
1274 }
1275 
TEST_F(tls,blocking)1276 TEST_F(tls, blocking)
1277 {
1278 	size_t data = 100000;
1279 	int res = fork();
1280 
1281 	EXPECT_NE(res, -1);
1282 
1283 	if (res) {
1284 		/* parent */
1285 		size_t left = data;
1286 		char buf[16384];
1287 		int status;
1288 		int pid2;
1289 
1290 		while (left) {
1291 			int res = send(self->fd, buf,
1292 				       left > 16384 ? 16384 : left, 0);
1293 
1294 			EXPECT_GE(res, 0);
1295 			left -= res;
1296 		}
1297 
1298 		pid2 = wait(&status);
1299 		EXPECT_EQ(status, 0);
1300 		EXPECT_EQ(res, pid2);
1301 	} else {
1302 		/* child */
1303 		size_t left = data;
1304 		char buf[16384];
1305 
1306 		while (left) {
1307 			int res = recv(self->cfd, buf,
1308 				       left > 16384 ? 16384 : left, 0);
1309 
1310 			EXPECT_GE(res, 0);
1311 			left -= res;
1312 		}
1313 	}
1314 }
1315 
TEST_F(tls,nonblocking)1316 TEST_F(tls, nonblocking)
1317 {
1318 	size_t data = 100000;
1319 	int sendbuf = 100;
1320 	int flags;
1321 	int res;
1322 
1323 	flags = fcntl(self->fd, F_GETFL, 0);
1324 	fcntl(self->fd, F_SETFL, flags | O_NONBLOCK);
1325 	fcntl(self->cfd, F_SETFL, flags | O_NONBLOCK);
1326 
1327 	/* Ensure nonblocking behavior by imposing a small send
1328 	 * buffer.
1329 	 */
1330 	EXPECT_EQ(setsockopt(self->fd, SOL_SOCKET, SO_SNDBUF,
1331 			     &sendbuf, sizeof(sendbuf)), 0);
1332 
1333 	res = fork();
1334 	EXPECT_NE(res, -1);
1335 
1336 	if (res) {
1337 		/* parent */
1338 		bool eagain = false;
1339 		size_t left = data;
1340 		char buf[16384];
1341 		int status;
1342 		int pid2;
1343 
1344 		while (left) {
1345 			int res = send(self->fd, buf,
1346 				       left > 16384 ? 16384 : left, 0);
1347 
1348 			if (res == -1 && errno == EAGAIN) {
1349 				eagain = true;
1350 				usleep(10000);
1351 				continue;
1352 			}
1353 			EXPECT_GE(res, 0);
1354 			left -= res;
1355 		}
1356 
1357 		EXPECT_TRUE(eagain);
1358 		pid2 = wait(&status);
1359 
1360 		EXPECT_EQ(status, 0);
1361 		EXPECT_EQ(res, pid2);
1362 	} else {
1363 		/* child */
1364 		bool eagain = false;
1365 		size_t left = data;
1366 		char buf[16384];
1367 
1368 		while (left) {
1369 			int res = recv(self->cfd, buf,
1370 				       left > 16384 ? 16384 : left, 0);
1371 
1372 			if (res == -1 && errno == EAGAIN) {
1373 				eagain = true;
1374 				usleep(10000);
1375 				continue;
1376 			}
1377 			EXPECT_GE(res, 0);
1378 			left -= res;
1379 		}
1380 		EXPECT_TRUE(eagain);
1381 	}
1382 }
1383 
1384 static void
test_mutliproc(struct __test_metadata * _metadata,struct _test_data_tls * self,bool sendpg,unsigned int n_readers,unsigned int n_writers)1385 test_mutliproc(struct __test_metadata *_metadata, struct _test_data_tls *self,
1386 	       bool sendpg, unsigned int n_readers, unsigned int n_writers)
1387 {
1388 	const unsigned int n_children = n_readers + n_writers;
1389 	const size_t data = 6 * 1000 * 1000;
1390 	const size_t file_sz = data / 100;
1391 	size_t read_bias, write_bias;
1392 	int i, fd, child_id;
1393 	char buf[file_sz];
1394 	pid_t pid;
1395 
1396 	/* Only allow multiples for simplicity */
1397 	ASSERT_EQ(!(n_readers % n_writers) || !(n_writers % n_readers), true);
1398 	read_bias = n_writers / n_readers ?: 1;
1399 	write_bias = n_readers / n_writers ?: 1;
1400 
1401 	/* prep a file to send */
1402 	fd = open("/tmp/", O_TMPFILE | O_RDWR, 0600);
1403 	ASSERT_GE(fd, 0);
1404 
1405 	memset(buf, 0xac, file_sz);
1406 	ASSERT_EQ(write(fd, buf, file_sz), file_sz);
1407 
1408 	/* spawn children */
1409 	for (child_id = 0; child_id < n_children; child_id++) {
1410 		pid = fork();
1411 		ASSERT_NE(pid, -1);
1412 		if (!pid)
1413 			break;
1414 	}
1415 
1416 	/* parent waits for all children */
1417 	if (pid) {
1418 		for (i = 0; i < n_children; i++) {
1419 			int status;
1420 
1421 			wait(&status);
1422 			EXPECT_EQ(status, 0);
1423 		}
1424 
1425 		return;
1426 	}
1427 
1428 	/* Split threads for reading and writing */
1429 	if (child_id < n_readers) {
1430 		size_t left = data * read_bias;
1431 		char rb[8001];
1432 
1433 		while (left) {
1434 			int res;
1435 
1436 			res = recv(self->cfd, rb,
1437 				   left > sizeof(rb) ? sizeof(rb) : left, 0);
1438 
1439 			EXPECT_GE(res, 0);
1440 			left -= res;
1441 		}
1442 	} else {
1443 		size_t left = data * write_bias;
1444 
1445 		while (left) {
1446 			int res;
1447 
1448 			ASSERT_EQ(lseek(fd, 0, SEEK_SET), 0);
1449 			if (sendpg)
1450 				res = sendfile(self->fd, fd, NULL,
1451 					       left > file_sz ? file_sz : left);
1452 			else
1453 				res = send(self->fd, buf,
1454 					   left > file_sz ? file_sz : left, 0);
1455 
1456 			EXPECT_GE(res, 0);
1457 			left -= res;
1458 		}
1459 	}
1460 }
1461 
TEST_F(tls,mutliproc_even)1462 TEST_F(tls, mutliproc_even)
1463 {
1464 	test_mutliproc(_metadata, self, false, 6, 6);
1465 }
1466 
TEST_F(tls,mutliproc_readers)1467 TEST_F(tls, mutliproc_readers)
1468 {
1469 	test_mutliproc(_metadata, self, false, 4, 12);
1470 }
1471 
TEST_F(tls,mutliproc_writers)1472 TEST_F(tls, mutliproc_writers)
1473 {
1474 	test_mutliproc(_metadata, self, false, 10, 2);
1475 }
1476 
TEST_F(tls,mutliproc_sendpage_even)1477 TEST_F(tls, mutliproc_sendpage_even)
1478 {
1479 	test_mutliproc(_metadata, self, true, 6, 6);
1480 }
1481 
TEST_F(tls,mutliproc_sendpage_readers)1482 TEST_F(tls, mutliproc_sendpage_readers)
1483 {
1484 	test_mutliproc(_metadata, self, true, 4, 12);
1485 }
1486 
TEST_F(tls,mutliproc_sendpage_writers)1487 TEST_F(tls, mutliproc_sendpage_writers)
1488 {
1489 	test_mutliproc(_metadata, self, true, 10, 2);
1490 }
1491 
TEST_F(tls,control_msg)1492 TEST_F(tls, control_msg)
1493 {
1494 	char *test_str = "test_read";
1495 	char record_type = 100;
1496 	int send_len = 10;
1497 	char buf[10];
1498 
1499 	if (self->notls)
1500 		SKIP(return, "no TLS support");
1501 
1502 	EXPECT_EQ(tls_send_cmsg(self->fd, record_type, test_str, send_len, 0),
1503 		  send_len);
1504 	/* Should fail because we didn't provide a control message */
1505 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
1506 
1507 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1508 				buf, sizeof(buf), MSG_WAITALL | MSG_PEEK),
1509 		  send_len);
1510 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1511 
1512 	/* Recv the message again without MSG_PEEK */
1513 	memset(buf, 0, sizeof(buf));
1514 
1515 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1516 				buf, sizeof(buf), MSG_WAITALL),
1517 		  send_len);
1518 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1519 }
1520 
TEST_F(tls,control_msg_nomerge)1521 TEST_F(tls, control_msg_nomerge)
1522 {
1523 	char *rec1 = "1111";
1524 	char *rec2 = "2222";
1525 	int send_len = 5;
1526 	char buf[15];
1527 
1528 	if (self->notls)
1529 		SKIP(return, "no TLS support");
1530 
1531 	EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec1, send_len, 0), send_len);
1532 	EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec2, send_len, 0), send_len);
1533 
1534 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), MSG_PEEK), send_len);
1535 	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1536 
1537 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), MSG_PEEK), send_len);
1538 	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1539 
1540 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), 0), send_len);
1541 	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1542 
1543 	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), 0), send_len);
1544 	EXPECT_EQ(memcmp(buf, rec2, send_len), 0);
1545 }
1546 
TEST_F(tls,data_control_data)1547 TEST_F(tls, data_control_data)
1548 {
1549 	char *rec1 = "1111";
1550 	char *rec2 = "2222";
1551 	char *rec3 = "3333";
1552 	int send_len = 5;
1553 	char buf[15];
1554 
1555 	if (self->notls)
1556 		SKIP(return, "no TLS support");
1557 
1558 	EXPECT_EQ(send(self->fd, rec1, send_len, 0), send_len);
1559 	EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec2, send_len, 0), send_len);
1560 	EXPECT_EQ(send(self->fd, rec3, send_len, 0), send_len);
1561 
1562 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
1563 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
1564 }
1565 
TEST_F(tls,shutdown)1566 TEST_F(tls, shutdown)
1567 {
1568 	char const *test_str = "test_read";
1569 	int send_len = 10;
1570 	char buf[10];
1571 
1572 	ASSERT_EQ(strlen(test_str) + 1, send_len);
1573 
1574 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1575 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1576 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1577 
1578 	shutdown(self->fd, SHUT_RDWR);
1579 	shutdown(self->cfd, SHUT_RDWR);
1580 }
1581 
TEST_F(tls,shutdown_unsent)1582 TEST_F(tls, shutdown_unsent)
1583 {
1584 	char const *test_str = "test_read";
1585 	int send_len = 10;
1586 
1587 	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
1588 
1589 	shutdown(self->fd, SHUT_RDWR);
1590 	shutdown(self->cfd, SHUT_RDWR);
1591 }
1592 
TEST_F(tls,shutdown_reuse)1593 TEST_F(tls, shutdown_reuse)
1594 {
1595 	struct sockaddr_in addr;
1596 	int ret;
1597 
1598 	shutdown(self->fd, SHUT_RDWR);
1599 	shutdown(self->cfd, SHUT_RDWR);
1600 	close(self->cfd);
1601 
1602 	addr.sin_family = AF_INET;
1603 	addr.sin_addr.s_addr = htonl(INADDR_ANY);
1604 	addr.sin_port = 0;
1605 
1606 	ret = bind(self->fd, &addr, sizeof(addr));
1607 	EXPECT_EQ(ret, 0);
1608 	ret = listen(self->fd, 10);
1609 	EXPECT_EQ(ret, -1);
1610 	EXPECT_EQ(errno, EINVAL);
1611 
1612 	ret = connect(self->fd, &addr, sizeof(addr));
1613 	EXPECT_EQ(ret, -1);
1614 	EXPECT_EQ(errno, EISCONN);
1615 }
1616 
TEST_F(tls,getsockopt)1617 TEST_F(tls, getsockopt)
1618 {
1619 	struct tls_crypto_info_keys expect, get;
1620 	socklen_t len;
1621 
1622 	/* get only the version/cipher */
1623 	len = sizeof(struct tls_crypto_info);
1624 	memrnd(&get, sizeof(get));
1625 	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), 0);
1626 	EXPECT_EQ(len, sizeof(struct tls_crypto_info));
1627 	EXPECT_EQ(get.crypto_info.version, variant->tls_version);
1628 	EXPECT_EQ(get.crypto_info.cipher_type, variant->cipher_type);
1629 
1630 	/* get the full crypto_info */
1631 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &expect, 0);
1632 	len = expect.len;
1633 	memrnd(&get, sizeof(get));
1634 	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), 0);
1635 	EXPECT_EQ(len, expect.len);
1636 	EXPECT_EQ(get.crypto_info.version, variant->tls_version);
1637 	EXPECT_EQ(get.crypto_info.cipher_type, variant->cipher_type);
1638 	EXPECT_EQ(memcmp(&get, &expect, expect.len), 0);
1639 
1640 	/* short get should fail */
1641 	len = sizeof(struct tls_crypto_info) - 1;
1642 	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), -1);
1643 	EXPECT_EQ(errno, EINVAL);
1644 
1645 	/* partial get of the cipher data should fail */
1646 	len = expect.len - 1;
1647 	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), -1);
1648 	EXPECT_EQ(errno, EINVAL);
1649 }
1650 
TEST_F(tls,recv_efault)1651 TEST_F(tls, recv_efault)
1652 {
1653 	char *rec1 = "1111111111";
1654 	char *rec2 = "2222222222";
1655 	struct msghdr hdr = {};
1656 	struct iovec iov[2];
1657 	char recv_mem[12];
1658 	int ret;
1659 
1660 	if (self->notls)
1661 		SKIP(return, "no TLS support");
1662 
1663 	EXPECT_EQ(send(self->fd, rec1, 10, 0), 10);
1664 	EXPECT_EQ(send(self->fd, rec2, 10, 0), 10);
1665 
1666 	iov[0].iov_base = recv_mem;
1667 	iov[0].iov_len = sizeof(recv_mem);
1668 	iov[1].iov_base = NULL; /* broken iov to make process_rx_list fail */
1669 	iov[1].iov_len = 1;
1670 
1671 	hdr.msg_iovlen = 2;
1672 	hdr.msg_iov = iov;
1673 
1674 	EXPECT_EQ(recv(self->cfd, recv_mem, 1, 0), 1);
1675 	EXPECT_EQ(recv_mem[0], rec1[0]);
1676 
1677 	ret = recvmsg(self->cfd, &hdr, 0);
1678 	EXPECT_LE(ret, sizeof(recv_mem));
1679 	EXPECT_GE(ret, 9);
1680 	EXPECT_EQ(memcmp(rec1, recv_mem, 9), 0);
1681 	if (ret > 9)
1682 		EXPECT_EQ(memcmp(rec2, recv_mem + 9, ret - 9), 0);
1683 }
1684 
#define TLS_RECORD_TYPE_HANDSHAKE      0x16
/*
 * TLS 1.3 KeyUpdate handshake message: msg_type key_update (0x18), 24-bit
 * length 1, body update_not_requested (0). It is spelled as a byte array
 * rather than a string literal, so sizeof() is the 5-byte message length
 * and no trailing NUL byte is sent inside the handshake record.
 */
static const char key_update_msg[] = { 0x18, 0x00, 0x00, 0x01, 0x00 };
tls_send_keyupdate(struct __test_metadata * _metadata,int fd)1688 static void tls_send_keyupdate(struct __test_metadata *_metadata, int fd)
1689 {
1690 	size_t len = sizeof(key_update_msg);
1691 
1692 	EXPECT_EQ(tls_send_cmsg(fd, TLS_RECORD_TYPE_HANDSHAKE,
1693 				(char *)key_update_msg, len, 0),
1694 		  len);
1695 }
1696 
tls_recv_keyupdate(struct __test_metadata * _metadata,int fd,int flags)1697 static void tls_recv_keyupdate(struct __test_metadata *_metadata, int fd, int flags)
1698 {
1699 	char buf[100];
1700 
1701 	EXPECT_EQ(tls_recv_cmsg(_metadata, fd, TLS_RECORD_TYPE_HANDSHAKE, buf, sizeof(buf), flags),
1702 		  sizeof(key_update_msg));
1703 	EXPECT_EQ(memcmp(buf, key_update_msg, sizeof(key_update_msg)), 0);
1704 }
1705 
1706 /* set the key to 0 then 1 for RX, immediately to 1 for TX */
TEST_F(tls_basic,rekey_rx)1707 TEST_F(tls_basic, rekey_rx)
1708 {
1709 	struct tls_crypto_info_keys tls12_0, tls12_1;
1710 	char const *test_str = "test_message";
1711 	int send_len = strlen(test_str) + 1;
1712 	char buf[20];
1713 	int ret;
1714 
1715 	if (self->notls)
1716 		return;
1717 
1718 	tls_crypto_info_init(TLS_1_3_VERSION, TLS_CIPHER_AES_GCM_128,
1719 			     &tls12_0, 0);
1720 	tls_crypto_info_init(TLS_1_3_VERSION, TLS_CIPHER_AES_GCM_128,
1721 			     &tls12_1, 1);
1722 
1723 	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_1, tls12_1.len);
1724 	ASSERT_EQ(ret, 0);
1725 
1726 	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_0, tls12_0.len);
1727 	ASSERT_EQ(ret, 0);
1728 
1729 	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_1, tls12_1.len);
1730 	EXPECT_EQ(ret, 0);
1731 
1732 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1733 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1734 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1735 }
1736 
1737 /* set the key to 0 then 1 for TX, immediately to 1 for RX */
TEST_F(tls_basic,rekey_tx)1738 TEST_F(tls_basic, rekey_tx)
1739 {
1740 	struct tls_crypto_info_keys tls12_0, tls12_1;
1741 	char const *test_str = "test_message";
1742 	int send_len = strlen(test_str) + 1;
1743 	char buf[20];
1744 	int ret;
1745 
1746 	if (self->notls)
1747 		return;
1748 
1749 	tls_crypto_info_init(TLS_1_3_VERSION, TLS_CIPHER_AES_GCM_128,
1750 			     &tls12_0, 0);
1751 	tls_crypto_info_init(TLS_1_3_VERSION, TLS_CIPHER_AES_GCM_128,
1752 			     &tls12_1, 1);
1753 
1754 	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_0, tls12_0.len);
1755 	ASSERT_EQ(ret, 0);
1756 
1757 	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_1, tls12_1.len);
1758 	ASSERT_EQ(ret, 0);
1759 
1760 	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_1, tls12_1.len);
1761 	EXPECT_EQ(ret, 0);
1762 
1763 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1764 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1765 	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1766 }
1767 
TEST_F(tls_basic,disconnect)1768 TEST_F(tls_basic, disconnect)
1769 {
1770 	char const *test_str = "test_message";
1771 	int send_len = strlen(test_str) + 1;
1772 	struct tls_crypto_info_keys key;
1773 	struct sockaddr_in addr;
1774 	char buf[20];
1775 	int ret;
1776 
1777 	if (self->notls)
1778 		return;
1779 
1780 	tls_crypto_info_init(TLS_1_3_VERSION, TLS_CIPHER_AES_GCM_128,
1781 			     &key, 0);
1782 
1783 	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &key, key.len);
1784 	ASSERT_EQ(ret, 0);
1785 
1786 	/* Pre-queue the data so that setsockopt parses it but doesn't
1787 	 * dequeue it from the TCP socket. recvmsg would dequeue.
1788 	 */
1789 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1790 
1791 	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &key, key.len);
1792 	ASSERT_EQ(ret, 0);
1793 
1794 	addr.sin_family = AF_UNSPEC;
1795 	addr.sin_addr.s_addr = htonl(INADDR_ANY);
1796 	addr.sin_port = 0;
1797 	ret = connect(self->cfd, &addr, sizeof(addr));
1798 	EXPECT_EQ(ret, -1);
1799 	EXPECT_EQ(errno, EOPNOTSUPP);
1800 
1801 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1802 }
1803 
TEST_F(tls,rekey)1804 TEST_F(tls, rekey)
1805 {
1806 	char const *test_str_1 = "test_message_before_rekey";
1807 	char const *test_str_2 = "test_message_after_rekey";
1808 	struct tls_crypto_info_keys tls12;
1809 	int send_len;
1810 	char buf[100];
1811 
1812 	if (variant->tls_version != TLS_1_3_VERSION)
1813 		return;
1814 
1815 	/* initial send/recv */
1816 	send_len = strlen(test_str_1) + 1;
1817 	EXPECT_EQ(send(self->fd, test_str_1, send_len, 0), send_len);
1818 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1819 	EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
1820 
1821 	/* update TX key */
1822 	tls_send_keyupdate(_metadata, self->fd);
1823 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1824 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
1825 
1826 	/* send after rekey */
1827 	send_len = strlen(test_str_2) + 1;
1828 	EXPECT_EQ(send(self->fd, test_str_2, send_len, 0), send_len);
1829 
1830 	/* can't receive the KeyUpdate without a control message */
1831 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
1832 
1833 	/* get KeyUpdate */
1834 	tls_recv_keyupdate(_metadata, self->cfd, 0);
1835 
1836 	/* recv blocking -> -EKEYEXPIRED */
1837 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), 0), -1);
1838 	EXPECT_EQ(errno, EKEYEXPIRED);
1839 
1840 	/* recv non-blocking -> -EKEYEXPIRED */
1841 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
1842 	EXPECT_EQ(errno, EKEYEXPIRED);
1843 
1844 	/* update RX key */
1845 	EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
1846 
1847 	/* recv after rekey */
1848 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1849 	EXPECT_EQ(memcmp(buf, test_str_2, send_len), 0);
1850 }
1851 
TEST_F(tls,rekey_fail)1852 TEST_F(tls, rekey_fail)
1853 {
1854 	char const *test_str_1 = "test_message_before_rekey";
1855 	char const *test_str_2 = "test_message_after_rekey";
1856 	struct tls_crypto_info_keys tls12;
1857 	int send_len;
1858 	char buf[100];
1859 
1860 	/* initial send/recv */
1861 	send_len = strlen(test_str_1) + 1;
1862 	EXPECT_EQ(send(self->fd, test_str_1, send_len, 0), send_len);
1863 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1864 	EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
1865 
1866 	/* update TX key */
1867 	tls_send_keyupdate(_metadata, self->fd);
1868 
1869 	if (variant->tls_version != TLS_1_3_VERSION) {
1870 		/* just check that rekey is not supported and return */
1871 		tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1872 		EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), -1);
1873 		EXPECT_EQ(errno, EBUSY);
1874 		return;
1875 	}
1876 
1877 	/* successful update */
1878 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1879 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
1880 
1881 	/* invalid update: change of version */
1882 	tls_crypto_info_init(TLS_1_2_VERSION, variant->cipher_type, &tls12, 1);
1883 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), -1);
1884 	EXPECT_EQ(errno, EINVAL);
1885 
1886 	/* invalid update (RX socket): change of version */
1887 	tls_crypto_info_init(TLS_1_2_VERSION, variant->cipher_type, &tls12, 1);
1888 	EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), -1);
1889 	EXPECT_EQ(errno, EINVAL);
1890 
1891 	/* invalid update: change of cipher */
1892 	if (variant->cipher_type == TLS_CIPHER_AES_GCM_256)
1893 		tls_crypto_info_init(variant->tls_version, TLS_CIPHER_CHACHA20_POLY1305, &tls12, 1);
1894 	else
1895 		tls_crypto_info_init(variant->tls_version, TLS_CIPHER_AES_GCM_256, &tls12, 1);
1896 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), -1);
1897 	EXPECT_EQ(errno, EINVAL);
1898 
1899 	/* send after rekey, the invalid updates shouldn't have an effect */
1900 	send_len = strlen(test_str_2) + 1;
1901 	EXPECT_EQ(send(self->fd, test_str_2, send_len, 0), send_len);
1902 
1903 	/* can't receive the KeyUpdate without a control message */
1904 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
1905 
1906 	/* get KeyUpdate */
1907 	tls_recv_keyupdate(_metadata, self->cfd, 0);
1908 
1909 	/* recv blocking -> -EKEYEXPIRED */
1910 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), 0), -1);
1911 	EXPECT_EQ(errno, EKEYEXPIRED);
1912 
1913 	/* recv non-blocking -> -EKEYEXPIRED */
1914 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
1915 	EXPECT_EQ(errno, EKEYEXPIRED);
1916 
1917 	/* update RX key */
1918 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1919 	EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
1920 
1921 	/* recv after rekey */
1922 	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1923 	EXPECT_EQ(memcmp(buf, test_str_2, send_len), 0);
1924 }
1925 
TEST_F(tls,rekey_peek)1926 TEST_F(tls, rekey_peek)
1927 {
1928 	char const *test_str_1 = "test_message_before_rekey";
1929 	struct tls_crypto_info_keys tls12;
1930 	int send_len;
1931 	char buf[100];
1932 
1933 	if (variant->tls_version != TLS_1_3_VERSION)
1934 		return;
1935 
1936 	send_len = strlen(test_str_1) + 1;
1937 	EXPECT_EQ(send(self->fd, test_str_1, send_len, 0), send_len);
1938 
1939 	/* update TX key */
1940 	tls_send_keyupdate(_metadata, self->fd);
1941 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1942 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
1943 
1944 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
1945 	EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
1946 
1947 	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1948 	EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
1949 
1950 	/* can't receive the KeyUpdate without a control message */
1951 	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
1952 
1953 	/* peek KeyUpdate */
1954 	tls_recv_keyupdate(_metadata, self->cfd, MSG_PEEK);
1955 
1956 	/* get KeyUpdate */
1957 	tls_recv_keyupdate(_metadata, self->cfd, 0);
1958 
1959 	/* update RX key */
1960 	EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
1961 }
1962 
TEST_F(tls,splice_rekey)1963 TEST_F(tls, splice_rekey)
1964 {
1965 	int send_len = TLS_PAYLOAD_MAX_LEN / 2;
1966 	char mem_send[TLS_PAYLOAD_MAX_LEN];
1967 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
1968 	struct tls_crypto_info_keys tls12;
1969 	int p[2];
1970 
1971 	if (variant->tls_version != TLS_1_3_VERSION)
1972 		return;
1973 
1974 	memrnd(mem_send, sizeof(mem_send));
1975 
1976 	ASSERT_GE(pipe(p), 0);
1977 	EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
1978 
1979 	/* update TX key */
1980 	tls_send_keyupdate(_metadata, self->fd);
1981 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1982 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
1983 
1984 	EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
1985 
1986 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), send_len);
1987 	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
1988 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
1989 
1990 	/* can't splice the KeyUpdate */
1991 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), -1);
1992 	EXPECT_EQ(errno, EINVAL);
1993 
1994 	/* peek KeyUpdate */
1995 	tls_recv_keyupdate(_metadata, self->cfd, MSG_PEEK);
1996 
1997 	/* get KeyUpdate */
1998 	tls_recv_keyupdate(_metadata, self->cfd, 0);
1999 
2000 	/* can't splice before updating the key */
2001 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), -1);
2002 	EXPECT_EQ(errno, EKEYEXPIRED);
2003 
2004 	/* update RX key */
2005 	EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2006 
2007 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), send_len);
2008 	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
2009 	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
2010 }
2011 
TEST_F(tls,rekey_peek_splice)2012 TEST_F(tls, rekey_peek_splice)
2013 {
2014 	char const *test_str_1 = "test_message_before_rekey";
2015 	struct tls_crypto_info_keys tls12;
2016 	int send_len;
2017 	char buf[100];
2018 	char mem_recv[TLS_PAYLOAD_MAX_LEN];
2019 	int p[2];
2020 
2021 	if (variant->tls_version != TLS_1_3_VERSION)
2022 		return;
2023 
2024 	ASSERT_GE(pipe(p), 0);
2025 
2026 	send_len = strlen(test_str_1) + 1;
2027 	EXPECT_EQ(send(self->fd, test_str_1, send_len, 0), send_len);
2028 
2029 	/* update TX key */
2030 	tls_send_keyupdate(_metadata, self->fd);
2031 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
2032 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2033 
2034 	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
2035 	EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
2036 
2037 	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), send_len);
2038 	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
2039 	EXPECT_EQ(memcmp(mem_recv, test_str_1, send_len), 0);
2040 }
2041 
TEST_F(tls,rekey_getsockopt)2042 TEST_F(tls, rekey_getsockopt)
2043 {
2044 	struct tls_crypto_info_keys tls12;
2045 	struct tls_crypto_info_keys tls12_get;
2046 	socklen_t len;
2047 
2048 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 0);
2049 
2050 	len = tls12.len;
2051 	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_get, &len), 0);
2052 	EXPECT_EQ(len, tls12.len);
2053 	EXPECT_EQ(memcmp(&tls12_get, &tls12, tls12.len), 0);
2054 
2055 	len = tls12.len;
2056 	EXPECT_EQ(getsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_get, &len), 0);
2057 	EXPECT_EQ(len, tls12.len);
2058 	EXPECT_EQ(memcmp(&tls12_get, &tls12, tls12.len), 0);
2059 
2060 	if (variant->tls_version != TLS_1_3_VERSION)
2061 		return;
2062 
2063 	tls_send_keyupdate(_metadata, self->fd);
2064 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
2065 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2066 
2067 	tls_recv_keyupdate(_metadata, self->cfd, 0);
2068 	EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2069 
2070 	len = tls12.len;
2071 	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_get, &len), 0);
2072 	EXPECT_EQ(len, tls12.len);
2073 	EXPECT_EQ(memcmp(&tls12_get, &tls12, tls12.len), 0);
2074 
2075 	len = tls12.len;
2076 	EXPECT_EQ(getsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_get, &len), 0);
2077 	EXPECT_EQ(len, tls12.len);
2078 	EXPECT_EQ(memcmp(&tls12_get, &tls12, tls12.len), 0);
2079 }
2080 
TEST_F(tls,rekey_poll_pending)2081 TEST_F(tls, rekey_poll_pending)
2082 {
2083 	char const *test_str = "test_message_after_rekey";
2084 	struct tls_crypto_info_keys tls12;
2085 	struct pollfd pfd = { };
2086 	int send_len;
2087 	int ret;
2088 
2089 	if (variant->tls_version != TLS_1_3_VERSION)
2090 		return;
2091 
2092 	/* update TX key */
2093 	tls_send_keyupdate(_metadata, self->fd);
2094 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
2095 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2096 
2097 	/* get KeyUpdate */
2098 	tls_recv_keyupdate(_metadata, self->cfd, 0);
2099 
2100 	/* send immediately after rekey */
2101 	send_len = strlen(test_str) + 1;
2102 	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
2103 
2104 	/* key hasn't been updated, expect cfd to be non-readable */
2105 	pfd.fd = self->cfd;
2106 	pfd.events = POLLIN;
2107 	EXPECT_EQ(poll(&pfd, 1, 0), 0);
2108 
2109 	ret = fork();
2110 	ASSERT_GE(ret, 0);
2111 
2112 	if (ret) {
2113 		int pid2, status;
2114 
2115 		/* wait before installing the new key */
2116 		sleep(1);
2117 
2118 		/* update RX key while poll() is sleeping */
2119 		EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2120 
2121 		pid2 = wait(&status);
2122 		EXPECT_EQ(pid2, ret);
2123 		EXPECT_EQ(status, 0);
2124 	} else {
2125 		pfd.fd = self->cfd;
2126 		pfd.events = POLLIN;
2127 		EXPECT_EQ(poll(&pfd, 1, 5000), 1);
2128 
2129 		exit(!__test_passed(_metadata));
2130 	}
2131 }
2132 
TEST_F(tls,rekey_poll_delay)2133 TEST_F(tls, rekey_poll_delay)
2134 {
2135 	char const *test_str = "test_message_after_rekey";
2136 	struct tls_crypto_info_keys tls12;
2137 	struct pollfd pfd = { };
2138 	int send_len;
2139 	int ret;
2140 
2141 	if (variant->tls_version != TLS_1_3_VERSION)
2142 		return;
2143 
2144 	/* update TX key */
2145 	tls_send_keyupdate(_metadata, self->fd);
2146 	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
2147 	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2148 
2149 	/* get KeyUpdate */
2150 	tls_recv_keyupdate(_metadata, self->cfd, 0);
2151 
2152 	ret = fork();
2153 	ASSERT_GE(ret, 0);
2154 
2155 	if (ret) {
2156 		int pid2, status;
2157 
2158 		/* wait before installing the new key */
2159 		sleep(1);
2160 
2161 		/* update RX key while poll() is sleeping */
2162 		EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2163 
2164 		sleep(1);
2165 		send_len = strlen(test_str) + 1;
2166 		EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
2167 
2168 		pid2 = wait(&status);
2169 		EXPECT_EQ(pid2, ret);
2170 		EXPECT_EQ(status, 0);
2171 	} else {
2172 		pfd.fd = self->cfd;
2173 		pfd.events = POLLIN;
2174 		EXPECT_EQ(poll(&pfd, 1, 5000), 1);
2175 		exit(!__test_passed(_metadata));
2176 	}
2177 }
2178 
/*
 * One TLS record in two forms: plain_data/plain_len is the payload the
 * receiver is expected to return, cipher_data/cipher_len is the complete
 * record (header included) exactly as it is written onto the wire.
 */
struct raw_rec {
	unsigned int plain_len;
	unsigned char plain_data[100];
	unsigned int cipher_len;
	unsigned char cipher_data[128];
};
2185 
/*
 * Pre-encrypted TLS 1.2 AES-CCM records for the zero_len tests, named
 * id<seqno>_<data|ctrl>_l<plaintext length>. Each cipher_data is laid out as:
 *   byte  0     content type (0x17 for "data", 0x16 for "ctrl" records)
 *   bytes 1-2   version 0x0303
 *   bytes 3-4   length of the rest of the record (8 + plain_len + 16)
 *   bytes 5-12  explicit nonce, equal to the record sequence number
 *   then the ciphertext followed by a 16-byte tag.
 * NOTE(review): presumably generated with the generation-0 AES-CCM-128 key
 * that FIXTURE_SETUP(zero_len) installs on the receiver - confirm before
 * regenerating.
 */

/* TLS 1.2, AES_CCM, data, seqno:0, plaintext: 'Hello world' */
static const struct raw_rec id0_data_l11 = {
	.plain_len = 11,
	.plain_data = {
		0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f,
		0x72, 0x6c, 0x64,
	},
	.cipher_len = 40,
	.cipher_data = {
		0x17, 0x03, 0x03, 0x00, 0x23, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x26, 0xa2, 0x33,
		0xde, 0x8d, 0x94, 0xf0, 0x29, 0x6c, 0xb1, 0xaf,
		0x6a, 0x75, 0xb2, 0x93, 0xad, 0x45, 0xd5, 0xfd,
		0x03, 0x51, 0x57, 0x8f, 0xf9, 0xcc, 0x3b, 0x42,
	},
};

/* TLS 1.2, AES_CCM, ctrl, seqno:0, plaintext: '' */
static const struct raw_rec id0_ctrl_l0 = {
	.plain_len = 0,
	.plain_data = {
	},
	.cipher_len = 29,
	.cipher_data = {
		0x16, 0x03, 0x03, 0x00, 0x18, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x38, 0x7b,
		0xa6, 0x1c, 0xdd, 0xa7, 0x19, 0x33, 0xab, 0xae,
		0x88, 0xe1, 0xd2, 0x08, 0x4f,
	},
};

/* TLS 1.2, AES_CCM, data, seqno:0, plaintext: '' */
static const struct raw_rec id0_data_l0 = {
	.plain_len = 0,
	.plain_data = {
	},
	.cipher_len = 29,
	.cipher_data = {
		0x17, 0x03, 0x03, 0x00, 0x18, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x00, 0xc5, 0x37, 0x90,
		0x70, 0x45, 0x89, 0xfb, 0x5c, 0xc7, 0x89, 0x03,
		0x68, 0x80, 0xd3, 0xd8, 0xcc,
	},
};

/* TLS 1.2, AES_CCM, data, seqno:1, plaintext: 'Hello world' */
static const struct raw_rec id1_data_l11 = {
	.plain_len = 11,
	.plain_data = {
		0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f,
		0x72, 0x6c, 0x64,
	},
	.cipher_len = 40,
	.cipher_data = {
		0x17, 0x03, 0x03, 0x00, 0x23, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x01, 0x3a, 0x1a, 0x9c,
		0xd0, 0xa8, 0x9a, 0xd6, 0x69, 0xd6, 0x1a, 0xe3,
		0xb5, 0x1f, 0x0d, 0x2c, 0xe2, 0x97, 0x46, 0xff,
		0x2b, 0xcc, 0x5a, 0xc4, 0xa3, 0xb9, 0xef, 0xba,
	},
};

/* TLS 1.2, AES_CCM, ctrl, seqno:1, plaintext: '' */
static const struct raw_rec id1_ctrl_l0 = {
	.plain_len = 0,
	.plain_data = {
	},
	.cipher_len = 29,
	.cipher_data = {
		0x16, 0x03, 0x03, 0x00, 0x18, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x01, 0x3e, 0xf0, 0xfe,
		0xee, 0xd9, 0xe2, 0x5d, 0xc7, 0x11, 0x4c, 0xe6,
		0xb4, 0x7e, 0xef, 0x40, 0x2b,
	},
};

/* TLS 1.2, AES_CCM, data, seqno:1, plaintext: '' */
static const struct raw_rec id1_data_l0 = {
	.plain_len = 0,
	.plain_data = {
	},
	.cipher_len = 29,
	.cipher_data = {
		0x17, 0x03, 0x03, 0x00, 0x18, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x01, 0xce, 0xfc, 0x86,
		0xc8, 0xf0, 0x55, 0xf9, 0x47, 0x3f, 0x74, 0xdc,
		0xc9, 0xbf, 0xfe, 0x5b, 0xb1,
	},
};

/* TLS 1.2, AES_CCM, ctrl, seqno:2, plaintext: 'Hello world' */
static const struct raw_rec id2_ctrl_l11 = {
	.plain_len = 11,
	.plain_data = {
		0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f,
		0x72, 0x6c, 0x64,
	},
	.cipher_len = 40,
	.cipher_data = {
		0x16, 0x03, 0x03, 0x00, 0x23, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x02, 0xe5, 0x3d, 0x19,
		0x3d, 0xca, 0xb8, 0x16, 0xb6, 0xff, 0x79, 0x87,
		0x2a, 0x04, 0x11, 0x3d, 0xf8, 0x64, 0x5f, 0x36,
		0x8b, 0xa8, 0xee, 0x4c, 0x6d, 0x62, 0xa5, 0x00,
	},
};

/*
 * TLS 1.2, AES_CCM, data, seqno:2, plaintext: 'Hello world'
 * Same nonce and plaintext as id2_ctrl_l11, so the 11 ciphertext bytes
 * match; only the header type and the tag differ.
 */
static const struct raw_rec id2_data_l11 = {
	.plain_len = 11,
	.plain_data = {
		0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f,
		0x72, 0x6c, 0x64,
	},
	.cipher_len = 40,
	.cipher_data = {
		0x17, 0x03, 0x03, 0x00, 0x23, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x02, 0xe5, 0x3d, 0x19,
		0x3d, 0xca, 0xb8, 0x16, 0xb6, 0xff, 0x79, 0x87,
		0x8e, 0xa1, 0xd0, 0xcd, 0x33, 0xb5, 0x86, 0x2b,
		0x17, 0xf1, 0x52, 0x2a, 0x55, 0x62, 0x65, 0x11,
	},
};

/* TLS 1.2, AES_CCM, ctrl, seqno:2, plaintext: '' */
static const struct raw_rec id2_ctrl_l0 = {
	.plain_len = 0,
	.plain_data = {
	},
	.cipher_len = 29,
	.cipher_data = {
		0x16, 0x03, 0x03, 0x00, 0x18, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x02, 0xdc, 0x5c, 0x0e,
		0x41, 0xdd, 0xba, 0xd3, 0xcc, 0xcf, 0x6d, 0xd9,
		0x06, 0xdb, 0x79, 0xe5, 0x5d,
	},
};

/* TLS 1.2, AES_CCM, data, seqno:2, plaintext: '' */
static const struct raw_rec id2_data_l0 = {
	.plain_len = 0,
	.plain_data = {
	},
	.cipher_len = 29,
	.cipher_data = {
		0x17, 0x03, 0x03, 0x00, 0x18, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x02, 0xc3, 0xca, 0x26,
		0x22, 0xe4, 0x25, 0xfb, 0x5f, 0x6d, 0xbf, 0x83,
		0x30, 0x48, 0x69, 0x1a, 0x47,
	},
};
2337 
/*
 * fd is the raw sending end (pre-encrypted records are written to it
 * directly), cfd is the TLS receiver; notls is set by ulp_sock_pair()
 * and gates the rest of the setup.
 */
FIXTURE(zero_len)
{
	int fd, cfd;
	bool notls;
};
2343 
/*
 * recs: up to four records sent in order; a NULL entry ends the list.
 * recv_ret: expected result of successive MSG_DONTWAIT reads - a byte
 * count, or a negated errno (the sequence ends at -EAGAIN).
 */
FIXTURE_VARIANT(zero_len)
{
	const struct raw_rec *recs[4];
	ssize_t recv_ret[4];
};
2349 
/* Three 11-byte data records come back in a single 33-byte read. */
FIXTURE_VARIANT_ADD(zero_len, data_data_data)
{
	.recs = { &id0_data_l11, &id1_data_l11, &id2_data_l11, },
	.recv_ret = { 33, -EAGAIN, },
};

/* An empty ctrl record splits the data: 11 bytes, 0 bytes, 11 bytes. */
FIXTURE_VARIANT_ADD(zero_len, data_0ctrl_data)
{
	.recs = { &id0_data_l11, &id1_ctrl_l0, &id2_data_l11, },
	.recv_ret = { 11, 0, 11, -EAGAIN, },
};

/* Only empty data records: nothing to read at all. */
FIXTURE_VARIANT_ADD(zero_len, 0data_0data_0data)
{
	.recs = { &id0_data_l0, &id1_data_l0, &id2_data_l0, },
	.recv_ret = { -EAGAIN, },
};

/* Empty data records, then an 11-byte ctrl record: 0 bytes, then 11. */
FIXTURE_VARIANT_ADD(zero_len, 0data_0data_ctrl)
{
	.recs = { &id0_data_l0, &id1_data_l0, &id2_ctrl_l11, },
	.recv_ret = { 0, 11, -EAGAIN, },
};

/* Empty data records, then an empty ctrl record: two 0-byte reads. */
FIXTURE_VARIANT_ADD(zero_len, 0data_0data_0ctrl)
{
	.recs = { &id0_data_l0, &id1_data_l0, &id2_ctrl_l0, },
	.recv_ret = { 0, 0, -EAGAIN, },
};

/* Each empty ctrl record is returned by its own 0-byte read. */
FIXTURE_VARIANT_ADD(zero_len, 0ctrl_0ctrl_0ctrl)
{
	.recs = { &id0_ctrl_l0, &id1_ctrl_l0, &id2_ctrl_l0, },
	.recv_ret = { 0, 0, 0, -EAGAIN, },
};

/* Leading empty data records are skipped; only the 11 bytes come back. */
FIXTURE_VARIANT_ADD(zero_len, 0data_0data_data)
{
	.recs = { &id0_data_l0, &id1_data_l0, &id2_data_l11, },
	.recv_ret = { 11, -EAGAIN, },
};

/* Trailing empty data records add nothing after the 11 bytes. */
FIXTURE_VARIANT_ADD(zero_len, data_0data_0data)
{
	.recs = { &id0_data_l11, &id1_data_l0, &id2_data_l0, },
	.recv_ret = { 11, -EAGAIN, },
};
2397 
FIXTURE_SETUP(zero_len)2398 FIXTURE_SETUP(zero_len)
2399 {
2400 	struct tls_crypto_info_keys tls12;
2401 	int ret;
2402 
2403 	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_CCM_128,
2404 			     &tls12, 0);
2405 
2406 	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
2407 	if (self->notls)
2408 		return;
2409 
2410 	/* Don't install keys on fd, we'll send raw records */
2411 	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len);
2412 	ASSERT_EQ(ret, 0);
2413 }
2414 
FIXTURE_TEARDOWN(zero_len)2415 FIXTURE_TEARDOWN(zero_len)
2416 {
2417 	close(self->fd);
2418 	close(self->cfd);
2419 }
2420 
TEST_F(zero_len,test)2421 TEST_F(zero_len, test)
2422 {
2423 	const struct raw_rec *const *rec;
2424 	unsigned char buf[128];
2425 	int rec_off;
2426 	int i;
2427 
2428 	for (i = 0; i < 4 && variant->recs[i]; i++)
2429 		EXPECT_EQ(send(self->fd, variant->recs[i]->cipher_data,
2430 			       variant->recs[i]->cipher_len, 0),
2431 			  variant->recs[i]->cipher_len);
2432 
2433 	rec = &variant->recs[0];
2434 	rec_off = 0;
2435 	for (i = 0; i < 4; i++) {
2436 		int j, ret;
2437 
2438 		ret = variant->recv_ret[i] >= 0 ? variant->recv_ret[i] : -1;
2439 		EXPECT_EQ(__tls_recv_cmsg(_metadata, self->cfd, NULL,
2440 					  buf, sizeof(buf), MSG_DONTWAIT), ret);
2441 		if (ret == -1)
2442 			EXPECT_EQ(errno, -variant->recv_ret[i]);
2443 		if (variant->recv_ret[i] == -EAGAIN)
2444 			break;
2445 
2446 		for (j = 0; j < ret; j++) {
2447 			while (rec_off == (*rec)->plain_len) {
2448 				rec++;
2449 				rec_off = 0;
2450 			}
2451 			EXPECT_EQ(buf[j], (*rec)->plain_data[rec_off]);
2452 			rec_off++;
2453 		}
2454 	}
2455 };
2456 
/*
 * Two socket pairs. fd gets a TLS TX key and cfd2 a TLS RX key; the tests
 * read ciphertext produced via fd from cfd and replay it (possibly
 * corrupted) into cfd2 through fd2. notls is set when TLS is unavailable.
 */
FIXTURE(tls_err)
{
	int fd, cfd;
	int fd2, cfd2;
	bool notls;
};
2463 
/* Protocol version used with the AES-GCM-128 key installed in setup. */
FIXTURE_VARIANT(tls_err)
{
	uint16_t tls_version;
};
2468 
/* Run every tls_err test with TLS 1.2 and TLS 1.3 (both AES-GCM-128). */
FIXTURE_VARIANT_ADD(tls_err, 12_aes_gcm)
{
	.tls_version = TLS_1_2_VERSION,
};

FIXTURE_VARIANT_ADD(tls_err, 13_aes_gcm)
{
	.tls_version = TLS_1_3_VERSION,
};
2478 
FIXTURE_SETUP(tls_err)2479 FIXTURE_SETUP(tls_err)
2480 {
2481 	struct tls_crypto_info_keys tls12;
2482 	int ret;
2483 
2484 	tls_crypto_info_init(variant->tls_version, TLS_CIPHER_AES_GCM_128,
2485 			     &tls12, 0);
2486 
2487 	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
2488 	ulp_sock_pair(_metadata, &self->fd2, &self->cfd2, &self->notls);
2489 	if (self->notls)
2490 		return;
2491 
2492 	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
2493 	ASSERT_EQ(ret, 0);
2494 
2495 	ret = setsockopt(self->cfd2, SOL_TLS, TLS_RX, &tls12, tls12.len);
2496 	ASSERT_EQ(ret, 0);
2497 }
2498 
FIXTURE_TEARDOWN(tls_err)2499 FIXTURE_TEARDOWN(tls_err)
2500 {
2501 	close(self->fd);
2502 	close(self->cfd);
2503 	close(self->fd2);
2504 	close(self->cfd2);
2505 }
2506 
TEST_F(tls_err,bad_rec)2507 TEST_F(tls_err, bad_rec)
2508 {
2509 	char buf[64];
2510 
2511 	if (self->notls)
2512 		SKIP(return, "no TLS support");
2513 
2514 	memset(buf, 0x55, sizeof(buf));
2515 	EXPECT_EQ(send(self->fd2, buf, sizeof(buf), 0), sizeof(buf));
2516 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2517 	EXPECT_EQ(errno, EMSGSIZE);
2518 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), MSG_DONTWAIT), -1);
2519 	EXPECT_EQ(errno, EAGAIN);
2520 }
2521 
TEST_F(tls_err,bad_auth)2522 TEST_F(tls_err, bad_auth)
2523 {
2524 	char buf[128];
2525 	int n;
2526 
2527 	if (self->notls)
2528 		SKIP(return, "no TLS support");
2529 
2530 	memrnd(buf, sizeof(buf) / 2);
2531 	EXPECT_EQ(send(self->fd, buf, sizeof(buf) / 2, 0), sizeof(buf) / 2);
2532 	n = recv(self->cfd, buf, sizeof(buf), 0);
2533 	EXPECT_GT(n, sizeof(buf) / 2);
2534 
2535 	buf[n - 1]++;
2536 
2537 	EXPECT_EQ(send(self->fd2, buf, n, 0), n);
2538 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2539 	EXPECT_EQ(errno, EBADMSG);
2540 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2541 	EXPECT_EQ(errno, EBADMSG);
2542 }
2543 
TEST_F(tls_err,bad_in_large_read)2544 TEST_F(tls_err, bad_in_large_read)
2545 {
2546 	char txt[3][64];
2547 	char cip[3][128];
2548 	char buf[3 * 128];
2549 	int i, n;
2550 
2551 	if (self->notls)
2552 		SKIP(return, "no TLS support");
2553 
2554 	/* Put 3 records in the sockets */
2555 	for (i = 0; i < 3; i++) {
2556 		memrnd(txt[i], sizeof(txt[i]));
2557 		EXPECT_EQ(send(self->fd, txt[i], sizeof(txt[i]), 0),
2558 			  sizeof(txt[i]));
2559 		n = recv(self->cfd, cip[i], sizeof(cip[i]), 0);
2560 		EXPECT_GT(n, sizeof(txt[i]));
2561 		/* Break the third message */
2562 		if (i == 2)
2563 			cip[2][n - 1]++;
2564 		EXPECT_EQ(send(self->fd2, cip[i], n, 0), n);
2565 	}
2566 
2567 	/* We should be able to receive the first two messages */
2568 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt[0]) * 2);
2569 	EXPECT_EQ(memcmp(buf, txt[0], sizeof(txt[0])), 0);
2570 	EXPECT_EQ(memcmp(buf + sizeof(txt[0]), txt[1], sizeof(txt[1])), 0);
2571 	/* Third mesasge is bad */
2572 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2573 	EXPECT_EQ(errno, EBADMSG);
2574 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2575 	EXPECT_EQ(errno, EBADMSG);
2576 }
2577 
TEST_F(tls_err,bad_cmsg)2578 TEST_F(tls_err, bad_cmsg)
2579 {
2580 	char *test_str = "test_read";
2581 	int send_len = 10;
2582 	char cip[128];
2583 	char buf[128];
2584 	char txt[64];
2585 	int n;
2586 
2587 	if (self->notls)
2588 		SKIP(return, "no TLS support");
2589 
2590 	/* Queue up one data record */
2591 	memrnd(txt, sizeof(txt));
2592 	EXPECT_EQ(send(self->fd, txt, sizeof(txt), 0), sizeof(txt));
2593 	n = recv(self->cfd, cip, sizeof(cip), 0);
2594 	EXPECT_GT(n, sizeof(txt));
2595 	EXPECT_EQ(send(self->fd2, cip, n, 0), n);
2596 
2597 	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
2598 	n = recv(self->cfd, cip, sizeof(cip), 0);
2599 	cip[n - 1]++; /* Break it */
2600 	EXPECT_GT(n, send_len);
2601 	EXPECT_EQ(send(self->fd2, cip, n, 0), n);
2602 
2603 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt));
2604 	EXPECT_EQ(memcmp(buf, txt, sizeof(txt)), 0);
2605 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2606 	EXPECT_EQ(errno, EBADMSG);
2607 	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2608 	EXPECT_EQ(errno, EBADMSG);
2609 }
2610 
TEST_F(tls_err,timeo)2611 TEST_F(tls_err, timeo)
2612 {
2613 	struct timeval tv = { .tv_usec = 10000, };
2614 	char buf[128];
2615 	int ret;
2616 
2617 	if (self->notls)
2618 		SKIP(return, "no TLS support");
2619 
2620 	ret = setsockopt(self->cfd2, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
2621 	ASSERT_EQ(ret, 0);
2622 
2623 	ret = fork();
2624 	ASSERT_GE(ret, 0);
2625 
2626 	if (ret) {
2627 		usleep(1000); /* Give child a head start */
2628 
2629 		EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2630 		EXPECT_EQ(errno, EAGAIN);
2631 
2632 		EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2633 		EXPECT_EQ(errno, EAGAIN);
2634 
2635 		wait(&ret);
2636 	} else {
2637 		EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2638 		EXPECT_EQ(errno, EAGAIN);
2639 		exit(0);
2640 	}
2641 }
2642 
TEST_F(tls_err,poll_partial_rec)2643 TEST_F(tls_err, poll_partial_rec)
2644 {
2645 	struct pollfd pfd = { };
2646 	ssize_t rec_len;
2647 	char rec[256];
2648 	char buf[128];
2649 
2650 	if (self->notls)
2651 		SKIP(return, "no TLS support");
2652 
2653 	pfd.fd = self->cfd2;
2654 	pfd.events = POLLIN;
2655 	EXPECT_EQ(poll(&pfd, 1, 1), 0);
2656 
2657 	memrnd(buf, sizeof(buf));
2658 	EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
2659 	rec_len = recv(self->cfd, rec, sizeof(rec), 0);
2660 	EXPECT_GT(rec_len, sizeof(buf));
2661 
2662 	/* Write 100B, not the full record ... */
2663 	EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
2664 	/* ... no full record should mean no POLLIN */
2665 	pfd.fd = self->cfd2;
2666 	pfd.events = POLLIN;
2667 	EXPECT_EQ(poll(&pfd, 1, 1), 0);
2668 	/* Now write the rest, and it should all pop out of the other end. */
2669 	EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0), rec_len - 100);
2670 	pfd.fd = self->cfd2;
2671 	pfd.events = POLLIN;
2672 	EXPECT_EQ(poll(&pfd, 1, 1), 1);
2673 	EXPECT_EQ(recv(self->cfd2, rec, sizeof(rec), 0), sizeof(buf));
2674 	EXPECT_EQ(memcmp(buf, rec, sizeof(buf)), 0);
2675 }
2676 
TEST_F(tls_err,epoll_partial_rec)2677 TEST_F(tls_err, epoll_partial_rec)
2678 {
2679 	struct epoll_event ev, events[10];
2680 	ssize_t rec_len;
2681 	char rec[256];
2682 	char buf[128];
2683 	int epollfd;
2684 
2685 	if (self->notls)
2686 		SKIP(return, "no TLS support");
2687 
2688 	epollfd = epoll_create1(0);
2689 	ASSERT_GE(epollfd, 0);
2690 
2691 	memset(&ev, 0, sizeof(ev));
2692 	ev.events = EPOLLIN;
2693 	ev.data.fd = self->cfd2;
2694 	ASSERT_GE(epoll_ctl(epollfd, EPOLL_CTL_ADD, self->cfd2, &ev), 0);
2695 
2696 	EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 0);
2697 
2698 	memrnd(buf, sizeof(buf));
2699 	EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
2700 	rec_len = recv(self->cfd, rec, sizeof(rec), 0);
2701 	EXPECT_GT(rec_len, sizeof(buf));
2702 
2703 	/* Write 100B, not the full record ... */
2704 	EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
2705 	/* ... no full record should mean no POLLIN */
2706 	EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 0);
2707 	/* Now write the rest, and it should all pop out of the other end. */
2708 	EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0), rec_len - 100);
2709 	EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 1);
2710 	EXPECT_EQ(recv(self->cfd2, rec, sizeof(rec), 0), sizeof(buf));
2711 	EXPECT_EQ(memcmp(buf, rec, sizeof(buf)), 0);
2712 
2713 	close(epollfd);
2714 }
2715 
TEST_F(tls_err,poll_partial_rec_async)2716 TEST_F(tls_err, poll_partial_rec_async)
2717 {
2718 	struct pollfd pfd = { };
2719 	ssize_t rec_len;
2720 	char rec[256];
2721 	char buf[128];
2722 	char token;
2723 	int p[2];
2724 	int ret;
2725 
2726 	if (self->notls)
2727 		SKIP(return, "no TLS support");
2728 
2729 	ASSERT_GE(pipe(p), 0);
2730 
2731 	memrnd(buf, sizeof(buf));
2732 	EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
2733 	rec_len = recv(self->cfd, rec, sizeof(rec), 0);
2734 	EXPECT_GT(rec_len, sizeof(buf));
2735 
2736 	ret = fork();
2737 	ASSERT_GE(ret, 0);
2738 
2739 	if (ret) {
2740 		int status, pid2;
2741 
2742 		close(p[1]);
2743 		usleep(1000); /* Give child a head start */
2744 
2745 		EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
2746 
2747 		EXPECT_EQ(read(p[0], &token, 1), 1); /* Barrier #1 */
2748 
2749 		EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0),
2750 			  rec_len - 100);
2751 
2752 		pid2 = wait(&status);
2753 		EXPECT_EQ(pid2, ret);
2754 		EXPECT_EQ(status, 0);
2755 	} else {
2756 		close(p[0]);
2757 
2758 		/* Child should sleep in poll(), never get a wake */
2759 		pfd.fd = self->cfd2;
2760 		pfd.events = POLLIN;
2761 		EXPECT_EQ(poll(&pfd, 1, 20), 0);
2762 
2763 		EXPECT_EQ(write(p[1], &token, 1), 1); /* Barrier #1 */
2764 
2765 		pfd.fd = self->cfd2;
2766 		pfd.events = POLLIN;
2767 		EXPECT_EQ(poll(&pfd, 1, 20), 1);
2768 
2769 		exit(!__test_passed(_metadata));
2770 	}
2771 }
2772 
TEST(non_established)2773 TEST(non_established) {
2774 	struct tls12_crypto_info_aes_gcm_256 tls12;
2775 	struct sockaddr_in addr;
2776 	int sfd, ret, fd;
2777 	socklen_t len;
2778 
2779 	len = sizeof(addr);
2780 
2781 	memset(&tls12, 0, sizeof(tls12));
2782 	tls12.info.version = TLS_1_2_VERSION;
2783 	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
2784 
2785 	addr.sin_family = AF_INET;
2786 	addr.sin_addr.s_addr = htonl(INADDR_ANY);
2787 	addr.sin_port = 0;
2788 
2789 	fd = socket(AF_INET, SOCK_STREAM, 0);
2790 	sfd = socket(AF_INET, SOCK_STREAM, 0);
2791 
2792 	ret = bind(sfd, &addr, sizeof(addr));
2793 	ASSERT_EQ(ret, 0);
2794 	ret = listen(sfd, 10);
2795 	ASSERT_EQ(ret, 0);
2796 
2797 	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2798 	EXPECT_EQ(ret, -1);
2799 	/* TLS ULP not supported */
2800 	if (errno == ENOENT)
2801 		return;
2802 	EXPECT_EQ(errno, ENOTCONN);
2803 
2804 	ret = setsockopt(sfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2805 	EXPECT_EQ(ret, -1);
2806 	EXPECT_EQ(errno, ENOTCONN);
2807 
2808 	ret = getsockname(sfd, &addr, &len);
2809 	ASSERT_EQ(ret, 0);
2810 
2811 	ret = connect(fd, &addr, sizeof(addr));
2812 	ASSERT_EQ(ret, 0);
2813 
2814 	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2815 	ASSERT_EQ(ret, 0);
2816 
2817 	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2818 	EXPECT_EQ(ret, -1);
2819 	EXPECT_EQ(errno, EEXIST);
2820 
2821 	close(fd);
2822 	close(sfd);
2823 }
2824 
TEST(keysizes)2825 TEST(keysizes) {
2826 	struct tls12_crypto_info_aes_gcm_256 tls12;
2827 	int ret, fd, cfd;
2828 	bool notls;
2829 
2830 	memset(&tls12, 0, sizeof(tls12));
2831 	tls12.info.version = TLS_1_2_VERSION;
2832 	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
2833 
2834 	ulp_sock_pair(_metadata, &fd, &cfd, &notls);
2835 
2836 	if (!notls) {
2837 		ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
2838 				 sizeof(tls12));
2839 		EXPECT_EQ(ret, 0);
2840 
2841 		ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
2842 				 sizeof(tls12));
2843 		EXPECT_EQ(ret, 0);
2844 	}
2845 
2846 	close(fd);
2847 	close(cfd);
2848 }
2849 
TEST(no_pad)2850 TEST(no_pad) {
2851 	struct tls12_crypto_info_aes_gcm_256 tls12;
2852 	int ret, fd, cfd, val;
2853 	socklen_t len;
2854 	bool notls;
2855 
2856 	memset(&tls12, 0, sizeof(tls12));
2857 	tls12.info.version = TLS_1_3_VERSION;
2858 	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
2859 
2860 	ulp_sock_pair(_metadata, &fd, &cfd, &notls);
2861 
2862 	if (notls)
2863 		exit(KSFT_SKIP);
2864 
2865 	ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12, sizeof(tls12));
2866 	EXPECT_EQ(ret, 0);
2867 
2868 	ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12, sizeof(tls12));
2869 	EXPECT_EQ(ret, 0);
2870 
2871 	val = 1;
2872 	ret = setsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2873 			 (void *)&val, sizeof(val));
2874 	EXPECT_EQ(ret, 0);
2875 
2876 	len = sizeof(val);
2877 	val = 2;
2878 	ret = getsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2879 			 (void *)&val, &len);
2880 	EXPECT_EQ(ret, 0);
2881 	EXPECT_EQ(val, 1);
2882 	EXPECT_EQ(len, 4);
2883 
2884 	val = 0;
2885 	ret = setsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2886 			 (void *)&val, sizeof(val));
2887 	EXPECT_EQ(ret, 0);
2888 
2889 	len = sizeof(val);
2890 	val = 2;
2891 	ret = getsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2892 			 (void *)&val, &len);
2893 	EXPECT_EQ(ret, 0);
2894 	EXPECT_EQ(val, 0);
2895 	EXPECT_EQ(len, 4);
2896 
2897 	close(fd);
2898 	close(cfd);
2899 }
2900 
TEST(tls_v6ops)2901 TEST(tls_v6ops) {
2902 	struct tls_crypto_info_keys tls12;
2903 	struct sockaddr_in6 addr, addr2;
2904 	int sfd, ret, fd;
2905 	socklen_t len, len2;
2906 
2907 	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_128, &tls12, 0);
2908 
2909 	addr.sin6_family = AF_INET6;
2910 	addr.sin6_addr = in6addr_any;
2911 	addr.sin6_port = 0;
2912 
2913 	fd = socket(AF_INET6, SOCK_STREAM, 0);
2914 	sfd = socket(AF_INET6, SOCK_STREAM, 0);
2915 
2916 	ret = bind(sfd, &addr, sizeof(addr));
2917 	ASSERT_EQ(ret, 0);
2918 	ret = listen(sfd, 10);
2919 	ASSERT_EQ(ret, 0);
2920 
2921 	len = sizeof(addr);
2922 	ret = getsockname(sfd, &addr, &len);
2923 	ASSERT_EQ(ret, 0);
2924 
2925 	ret = connect(fd, &addr, sizeof(addr));
2926 	ASSERT_EQ(ret, 0);
2927 
2928 	len = sizeof(addr);
2929 	ret = getsockname(fd, &addr, &len);
2930 	ASSERT_EQ(ret, 0);
2931 
2932 	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2933 	if (ret) {
2934 		ASSERT_EQ(errno, ENOENT);
2935 		SKIP(return, "no TLS support");
2936 	}
2937 	ASSERT_EQ(ret, 0);
2938 
2939 	ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
2940 	ASSERT_EQ(ret, 0);
2941 
2942 	ret = setsockopt(fd, SOL_TLS, TLS_RX, &tls12, tls12.len);
2943 	ASSERT_EQ(ret, 0);
2944 
2945 	len2 = sizeof(addr2);
2946 	ret = getsockname(fd, &addr2, &len2);
2947 	ASSERT_EQ(ret, 0);
2948 
2949 	EXPECT_EQ(len2, len);
2950 	EXPECT_EQ(memcmp(&addr, &addr2, len), 0);
2951 
2952 	close(fd);
2953 	close(sfd);
2954 }
2955 
TEST(prequeue)2956 TEST(prequeue) {
2957 	struct tls_crypto_info_keys tls12;
2958 	char buf[20000], buf2[20000];
2959 	struct sockaddr_in addr;
2960 	int sfd, cfd, ret, fd;
2961 	socklen_t len;
2962 
2963 	len = sizeof(addr);
2964 	memrnd(buf, sizeof(buf));
2965 
2966 	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_256, &tls12, 0);
2967 
2968 	addr.sin_family = AF_INET;
2969 	addr.sin_addr.s_addr = htonl(INADDR_ANY);
2970 	addr.sin_port = 0;
2971 
2972 	fd = socket(AF_INET, SOCK_STREAM, 0);
2973 	sfd = socket(AF_INET, SOCK_STREAM, 0);
2974 
2975 	ASSERT_EQ(bind(sfd, &addr, sizeof(addr)), 0);
2976 	ASSERT_EQ(listen(sfd, 10), 0);
2977 	ASSERT_EQ(getsockname(sfd, &addr, &len), 0);
2978 	ASSERT_EQ(connect(fd, &addr, sizeof(addr)), 0);
2979 	ASSERT_GE(cfd = accept(sfd, &addr, &len), 0);
2980 	close(sfd);
2981 
2982 	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2983 	if (ret) {
2984 		ASSERT_EQ(errno, ENOENT);
2985 		SKIP(return, "no TLS support");
2986 	}
2987 
2988 	ASSERT_EQ(setsockopt(fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2989 	EXPECT_EQ(send(fd, buf, sizeof(buf), MSG_DONTWAIT), sizeof(buf));
2990 
2991 	ASSERT_EQ(setsockopt(cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls")), 0);
2992 	ASSERT_EQ(setsockopt(cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2993 	EXPECT_EQ(recv(cfd, buf2, sizeof(buf2), MSG_WAITALL), sizeof(buf2));
2994 
2995 	EXPECT_EQ(memcmp(buf, buf2, sizeof(buf)), 0);
2996 
2997 	close(fd);
2998 	close(cfd);
2999 }
3000 
/*
 * A reader already blocked in recv() on the plain TCP socket when TLS RX
 * gets configured. Per the final check's comment, the TLS layer then
 * parses bytes that do not start at a record boundary, so the last
 * non-blocking read must fail (with an errno that depends on the random
 * data). NOTE(review): the premise that the child's recv() consumes raw
 * TCP bytes is inferred from the test's name and its closing comment.
 */
TEST(data_steal) {
	struct tls_crypto_info_keys tls;
	char buf[20000], buf2[20000];
	struct sockaddr_in addr;
	int sfd, cfd, ret, fd;
	int pid, status;
	socklen_t len;

	len = sizeof(addr);
	memrnd(buf, sizeof(buf));

	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_256, &tls, 0);

	addr.sin_family = AF_INET;
	addr.sin_addr.s_addr = htonl(INADDR_ANY);
	addr.sin_port = 0;

	fd = socket(AF_INET, SOCK_STREAM, 0);
	sfd = socket(AF_INET, SOCK_STREAM, 0);

	/* loopback connection on an ephemeral port: fd <-> cfd */
	ASSERT_EQ(bind(sfd, &addr, sizeof(addr)), 0);
	ASSERT_EQ(listen(sfd, 10), 0);
	ASSERT_EQ(getsockname(sfd, &addr, &len), 0);
	ASSERT_EQ(connect(fd, &addr, sizeof(addr)), 0);
	ASSERT_GE(cfd = accept(sfd, &addr, &len), 0);
	close(sfd);

	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
	if (ret) {
		ASSERT_EQ(errno, ENOENT);
		SKIP(return, "no TLS support");
	}
	/* ULP attached, but no keys installed yet on either side */
	ASSERT_EQ(setsockopt(cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls")), 0);

	/* Spawn a child and get it into the read wait path of the underlying
	 * TCP socket.
	 */
	pid = fork();
	ASSERT_GE(pid, 0);
	if (!pid) {
		EXPECT_EQ(recv(cfd, buf, sizeof(buf) / 2, MSG_WAITALL),
			  sizeof(buf) / 2);
		exit(!__test_passed(_metadata));
	}

	/* let the child block in recv() before the keys are installed */
	usleep(10000);
	ASSERT_EQ(setsockopt(fd, SOL_TLS, TLS_TX, &tls, tls.len), 0);
	ASSERT_EQ(setsockopt(cfd, SOL_TLS, TLS_RX, &tls, tls.len), 0);

	EXPECT_EQ(send(fd, buf, sizeof(buf), 0), sizeof(buf));
	EXPECT_EQ(wait(&status), pid);
	EXPECT_EQ(status, 0);
	EXPECT_EQ(recv(cfd, buf2, sizeof(buf2), MSG_DONTWAIT), -1);
	/* Don't check errno, the error will be different depending
	 * on what random bytes TLS interpreted as the record length.
	 */

	close(fd);
	close(cfd);
}
3061 
fips_check(void)3062 static void __attribute__((constructor)) fips_check(void) {
3063 	int res;
3064 	FILE *f;
3065 
3066 	f = fopen("/proc/sys/crypto/fips_enabled", "r");
3067 	if (f) {
3068 		res = fscanf(f, "%d", &fips_enabled);
3069 		if (res != 1)
3070 			ksft_print_msg("ERROR: Couldn't read /proc/sys/crypto/fips_enabled\n");
3071 		fclose(f);
3072 	}
3073 }
3074 
/* Expands to main(), which runs every TEST/TEST_F registered above. */
TEST_HARNESS_MAIN
3076