xref: /linux/tools/testing/selftests/net/tcp_ao/key-management.c (revision 8b6d678fede700db6466d73f11fcbad496fa515e)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Author: Dmitry Safonov <dima@arista.com> */
3 #include <inttypes.h>
4 #include "../../../../include/linux/kernel.h"
5 #include "aolib.h"
6 
7 const size_t nr_packets = 20;
8 const size_t msg_len = 100;
9 const size_t quota = nr_packets * msg_len;
10 union tcp_addr wrong_addr;
11 #define SECOND_PASSWORD	"at all times sincere friends of freedom have been rare"
12 #define fault(type)	(inj == FAULT_ ## type)
13 
14 static const int test_vrf_ifindex = 200;
15 static const uint8_t test_vrf_tabid = 42;
16 static void setup_vrfs(void)
17 {
18 	int err;
19 
20 	if (!kernel_config_has(KCONFIG_NET_VRF))
21 		return;
22 
23 	err = add_vrf("ksft-vrf", test_vrf_tabid, test_vrf_ifindex, -1);
24 	if (err)
25 		test_error("Failed to add a VRF: %d", err);
26 
27 	err = link_set_up("ksft-vrf");
28 	if (err)
29 		test_error("Failed to bring up a VRF");
30 
31 	err = ip_route_add_vrf(veth_name, TEST_FAMILY,
32 			       this_ip_addr, this_ip_dest, test_vrf_tabid);
33 	if (err)
34 		test_error("Failed to add a route to VRF");
35 }
36 
37 
38 static int prepare_sk(union tcp_addr *addr, uint8_t sndid, uint8_t rcvid)
39 {
40 	int sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
41 
42 	if (sk < 0)
43 		test_error("socket()");
44 
45 	if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest,
46 			 DEFAULT_TEST_PREFIX, 100, 100))
47 		test_error("test_add_key()");
48 
49 	if (addr && test_add_key(sk, SECOND_PASSWORD, *addr,
50 				 DEFAULT_TEST_PREFIX, sndid, rcvid))
51 		test_error("test_add_key()");
52 
53 	return sk;
54 }
55 
/* Same as prepare_sk(), but the returned socket is put into listen state. */
static int prepare_lsk(union tcp_addr *addr, uint8_t sndid, uint8_t rcvid)
{
	int lsk;

	lsk = prepare_sk(addr, sndid, rcvid);
	if (listen(lsk, 10) != 0)
		test_error("listen()");

	return lsk;
}
65 
66 static int test_del_key(int sk, uint8_t sndid, uint8_t rcvid, bool async,
67 			int current_key, int rnext_key)
68 {
69 	struct tcp_ao_info_opt ao_info = {};
70 	struct tcp_ao_getsockopt key = {};
71 	struct tcp_ao_del del = {};
72 	sockaddr_af sockaddr;
73 	int err;
74 
75 	tcp_addr_to_sockaddr_in(&del.addr, &this_ip_dest, 0);
76 	del.prefix = DEFAULT_TEST_PREFIX;
77 	del.sndid = sndid;
78 	del.rcvid = rcvid;
79 
80 	if (current_key >= 0) {
81 		del.set_current = 1;
82 		del.current_key = (uint8_t)current_key;
83 	}
84 	if (rnext_key >= 0) {
85 		del.set_rnext = 1;
86 		del.rnext = (uint8_t)rnext_key;
87 	}
88 
89 	err = setsockopt(sk, IPPROTO_TCP, TCP_AO_DEL_KEY, &del, sizeof(del));
90 	if (err < 0)
91 		return -errno;
92 
93 	if (async)
94 		return 0;
95 
96 	tcp_addr_to_sockaddr_in(&sockaddr, &this_ip_dest, 0);
97 	err = test_get_one_ao(sk, &key, &sockaddr, sizeof(sockaddr),
98 			      DEFAULT_TEST_PREFIX, sndid, rcvid);
99 	if (!err)
100 		return -EEXIST;
101 	if (err != -E2BIG)
102 		test_error("getsockopt()");
103 	if (current_key < 0 && rnext_key < 0)
104 		return 0;
105 	if (test_get_ao_info(sk, &ao_info))
106 		test_error("getsockopt(TCP_AO_INFO) failed");
107 	if (current_key >= 0 && ao_info.current_key != (uint8_t)current_key)
108 		return -ENOTRECOVERABLE;
109 	if (rnext_key >= 0 && ao_info.rnext != (uint8_t)rnext_key)
110 		return -ENOTRECOVERABLE;
111 	return 0;
112 }
113 
114 static void try_delete_key(char *tst_name, int sk, uint8_t sndid, uint8_t rcvid,
115 			   bool async, int current_key, int rnext_key,
116 			   fault_t inj)
117 {
118 	int err;
119 
120 	err = test_del_key(sk, sndid, rcvid, async, current_key, rnext_key);
121 	if ((err == -EBUSY && fault(BUSY)) || (err == -EINVAL && fault(CURRNEXT))) {
122 		test_ok("%s: key deletion was prevented", tst_name);
123 		return;
124 	}
125 	if (err && fault(FIXME)) {
126 		test_xfail("%s: failed to delete the key %u:%u %d",
127 			   tst_name, sndid, rcvid, err);
128 		return;
129 	}
130 	if (!err) {
131 		if (fault(BUSY) || fault(CURRNEXT)) {
132 			test_fail("%s: the key was deleted %u:%u %d", tst_name,
133 				  sndid, rcvid, err);
134 		} else {
135 			test_ok("%s: the key was deleted", tst_name);
136 		}
137 		return;
138 	}
139 	test_fail("%s: can't delete the key %u:%u %d", tst_name, sndid, rcvid, err);
140 }
141 
142 static int test_set_key(int sk, int current_keyid, int rnext_keyid)
143 {
144 	struct tcp_ao_info_opt ao_info = {};
145 	int err;
146 
147 	if (current_keyid >= 0) {
148 		ao_info.set_current = 1;
149 		ao_info.current_key = (uint8_t)current_keyid;
150 	}
151 	if (rnext_keyid >= 0) {
152 		ao_info.set_rnext = 1;
153 		ao_info.rnext = (uint8_t)rnext_keyid;
154 	}
155 
156 	err = test_set_ao_info(sk, &ao_info);
157 	if (err)
158 		return err;
159 	if (test_get_ao_info(sk, &ao_info))
160 		test_error("getsockopt(TCP_AO_INFO) failed");
161 	if (current_keyid >= 0 && ao_info.current_key != (uint8_t)current_keyid)
162 		return -ENOTRECOVERABLE;
163 	if (rnext_keyid >= 0 && ao_info.rnext != (uint8_t)rnext_keyid)
164 		return -ENOTRECOVERABLE;
165 	return 0;
166 }
167 
168 static int test_add_current_rnext_key(int sk, const char *key, uint8_t keyflags,
169 				      union tcp_addr in_addr, uint8_t prefix,
170 				      bool set_current, bool set_rnext,
171 				      uint8_t sndid, uint8_t rcvid)
172 {
173 	struct tcp_ao_add tmp = {};
174 	int err;
175 
176 	err = test_prepare_key(&tmp, DEFAULT_TEST_ALGO, in_addr,
177 			       set_current, set_rnext,
178 			       prefix, 0, sndid, rcvid, 0, keyflags,
179 			       strlen(key), key);
180 	if (err)
181 		return err;
182 
183 
184 	err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp));
185 	if (err < 0)
186 		return -errno;
187 
188 	return test_verify_socket_key(sk, &tmp);
189 }
190 
191 static int __try_add_current_rnext_key(int sk, const char *key, uint8_t keyflags,
192 				       union tcp_addr in_addr, uint8_t prefix,
193 				       bool set_current, bool set_rnext,
194 				       uint8_t sndid, uint8_t rcvid)
195 {
196 	struct tcp_ao_info_opt ao_info = {};
197 	int err;
198 
199 	err = test_add_current_rnext_key(sk, key, keyflags, in_addr, prefix,
200 					 set_current, set_rnext, sndid, rcvid);
201 	if (err)
202 		return err;
203 
204 	if (test_get_ao_info(sk, &ao_info))
205 		test_error("getsockopt(TCP_AO_INFO) failed");
206 	if (set_current && ao_info.current_key != sndid)
207 		return -ENOTRECOVERABLE;
208 	if (set_rnext && ao_info.rnext != rcvid)
209 		return -ENOTRECOVERABLE;
210 	return 0;
211 }
212 
213 static void try_add_current_rnext_key(char *tst_name, int sk, const char *key,
214 				     uint8_t keyflags,
215 				     union tcp_addr in_addr, uint8_t prefix,
216 				     bool set_current, bool set_rnext,
217 				     uint8_t sndid, uint8_t rcvid, fault_t inj)
218 {
219 	int err;
220 
221 	err = __try_add_current_rnext_key(sk, key, keyflags, in_addr, prefix,
222 					  set_current, set_rnext, sndid, rcvid);
223 	if (!err && !fault(CURRNEXT)) {
224 		test_ok("%s", tst_name);
225 		return;
226 	}
227 	if (err == -EINVAL && fault(CURRNEXT)) {
228 		test_ok("%s", tst_name);
229 		return;
230 	}
231 	test_fail("%s", tst_name);
232 }
233 
234 static void check_closed_socket(void)
235 {
236 	int sk;
237 
238 	sk = prepare_sk(&this_ip_dest, 200, 200);
239 	try_delete_key("closed socket, delete a key", sk, 200, 200, 0, -1, -1, 0);
240 	try_delete_key("closed socket, delete all keys", sk, 100, 100, 0, -1, -1, 0);
241 	close(sk);
242 
243 	sk = prepare_sk(&this_ip_dest, 200, 200);
244 	if (test_set_key(sk, 100, 200))
245 		test_error("failed to set current/rnext keys");
246 	try_delete_key("closed socket, delete current key", sk, 100, 100, 0, -1, -1, FAULT_BUSY);
247 	try_delete_key("closed socket, delete rnext key", sk, 200, 200, 0, -1, -1, FAULT_BUSY);
248 	close(sk);
249 
250 	sk = prepare_sk(&this_ip_dest, 200, 200);
251 	if (test_add_key(sk, "Glory to heros!", this_ip_dest,
252 			 DEFAULT_TEST_PREFIX, 10, 11))
253 		test_error("test_add_key()");
254 	if (test_add_key(sk, "Glory to Ukraine!", this_ip_dest,
255 			 DEFAULT_TEST_PREFIX, 12, 13))
256 		test_error("test_add_key()");
257 	try_delete_key("closed socket, delete a key + set current/rnext", sk, 100, 100, 0, 10, 13, 0);
258 	try_delete_key("closed socket, force-delete current key", sk, 10, 11, 0, 200, -1, 0);
259 	try_delete_key("closed socket, force-delete rnext key", sk, 12, 13, 0, -1, 200, 0);
260 	try_delete_key("closed socket, delete current+rnext key", sk, 200, 200, 0, -1, -1, FAULT_BUSY);
261 	close(sk);
262 
263 	sk = prepare_sk(&this_ip_dest, 200, 200);
264 	if (test_set_key(sk, 100, 200))
265 		test_error("failed to set current/rnext keys");
266 	try_add_current_rnext_key("closed socket, add + change current key",
267 				  sk, "Laaaa! Lalala-la-la-lalala...", 0,
268 				  this_ip_dest, DEFAULT_TEST_PREFIX,
269 				  true, false, 10, 20, 0);
270 	try_add_current_rnext_key("closed socket, add + change rnext key",
271 				  sk, "Laaaa! Lalala-la-la-lalala...", 0,
272 				  this_ip_dest, DEFAULT_TEST_PREFIX,
273 				  false, true, 20, 10, 0);
274 	close(sk);
275 }
276 
277 static void assert_no_current_rnext(const char *tst_msg, int sk)
278 {
279 	struct tcp_ao_info_opt ao_info = {};
280 
281 	if (test_get_ao_info(sk, &ao_info))
282 		test_error("getsockopt(TCP_AO_INFO) failed");
283 
284 	errno = 0;
285 	if (ao_info.set_current || ao_info.set_rnext) {
286 		test_xfail("%s: the socket has current/rnext keys: %d:%d",
287 			   tst_msg,
288 			   (ao_info.set_current) ? ao_info.current_key : -1,
289 			   (ao_info.set_rnext) ? ao_info.rnext : -1);
290 	} else {
291 		test_ok("%s: the socket has no current/rnext keys", tst_msg);
292 	}
293 }
294 
295 static void assert_no_tcp_repair(void)
296 {
297 	struct tcp_ao_repair ao_img = {};
298 	socklen_t len = sizeof(ao_img);
299 	int sk, err;
300 
301 	sk = prepare_sk(&this_ip_dest, 200, 200);
302 	test_enable_repair(sk);
303 	if (listen(sk, 10))
304 		test_error("listen()");
305 	errno = 0;
306 	err = getsockopt(sk, SOL_TCP, TCP_AO_REPAIR, &ao_img, &len);
307 	if (err && errno == EPERM)
308 		test_ok("listen socket, getsockopt(TCP_AO_REPAIR) is restricted");
309 	else
310 		test_fail("listen socket, getsockopt(TCP_AO_REPAIR) works");
311 	errno = 0;
312 	err = setsockopt(sk, SOL_TCP, TCP_AO_REPAIR, &ao_img, sizeof(ao_img));
313 	if (err && errno == EPERM)
314 		test_ok("listen socket, setsockopt(TCP_AO_REPAIR) is restricted");
315 	else
316 		test_fail("listen socket, setsockopt(TCP_AO_REPAIR) works");
317 	close(sk);
318 }
319 
320 static void check_listen_socket(void)
321 {
322 	int sk, err;
323 
324 	sk = prepare_lsk(&this_ip_dest, 200, 200);
325 	try_delete_key("listen socket, delete a key", sk, 200, 200, 0, -1, -1, 0);
326 	try_delete_key("listen socket, delete all keys", sk, 100, 100, 0, -1, -1, 0);
327 	close(sk);
328 
329 	sk = prepare_lsk(&this_ip_dest, 200, 200);
330 	err = test_set_key(sk, 100, -1);
331 	if (err == -EINVAL)
332 		test_ok("listen socket, setting current key not allowed");
333 	else
334 		test_fail("listen socket, set current key");
335 	err = test_set_key(sk, -1, 200);
336 	if (err == -EINVAL)
337 		test_ok("listen socket, setting rnext key not allowed");
338 	else
339 		test_fail("listen socket, set rnext key");
340 	close(sk);
341 
342 	sk = prepare_sk(&this_ip_dest, 200, 200);
343 	if (test_set_key(sk, 100, 200))
344 		test_error("failed to set current/rnext keys");
345 	if (listen(sk, 10))
346 		test_error("listen()");
347 	assert_no_current_rnext("listen() after current/rnext keys set", sk);
348 	try_delete_key("listen socket, delete current key from before listen()", sk, 100, 100, 0, -1, -1, FAULT_FIXME);
349 	try_delete_key("listen socket, delete rnext key from before listen()", sk, 200, 200, 0, -1, -1, FAULT_FIXME);
350 	close(sk);
351 
352 	assert_no_tcp_repair();
353 
354 	sk = prepare_lsk(&this_ip_dest, 200, 200);
355 	if (test_add_key(sk, "Glory to heros!", this_ip_dest,
356 			 DEFAULT_TEST_PREFIX, 10, 11))
357 		test_error("test_add_key()");
358 	if (test_add_key(sk, "Glory to Ukraine!", this_ip_dest,
359 			 DEFAULT_TEST_PREFIX, 12, 13))
360 		test_error("test_add_key()");
361 	try_delete_key("listen socket, delete a key + set current/rnext", sk,
362 		       100, 100, 0, 10, 13, FAULT_CURRNEXT);
363 	try_delete_key("listen socket, force-delete current key", sk,
364 		       10, 11, 0, 200, -1, FAULT_CURRNEXT);
365 	try_delete_key("listen socket, force-delete rnext key", sk,
366 		       12, 13, 0, -1, 200, FAULT_CURRNEXT);
367 	try_delete_key("listen socket, delete a key", sk,
368 		       200, 200, 0, -1, -1, 0);
369 	close(sk);
370 
371 	sk = prepare_lsk(&this_ip_dest, 200, 200);
372 	try_add_current_rnext_key("listen socket, add + change current key",
373 				  sk, "Laaaa! Lalala-la-la-lalala...", 0,
374 				  this_ip_dest, DEFAULT_TEST_PREFIX,
375 				  true, false, 10, 20, FAULT_CURRNEXT);
376 	try_add_current_rnext_key("listen socket, add + change rnext key",
377 				  sk, "Laaaa! Lalala-la-la-lalala...", 0,
378 				  this_ip_dest, DEFAULT_TEST_PREFIX,
379 				  false, true, 20, 10, FAULT_CURRNEXT);
380 	close(sk);
381 }
382 
static const char *fips_fpath = "/proc/sys/crypto/fips_enabled";

/*
 * Report whether the integer stored in fips_fpath is non-zero.
 * The answer is computed once and cached in a function-local static.
 * A missing file (ENOENT) counts as "disabled"; other access/read problems
 * are reported through test_error().
 */
static bool is_fips_enabled(void)
{
	static int cached = -1;	/* -1: not probed yet, otherwise 0 or 1 */
	FILE *fp;
	int value;

	if (cached < 0) {
		if (access(fips_fpath, R_OK) != 0) {
			if (errno != ENOENT)
				test_error("Can't open %s", fips_fpath);
			cached = 0;
		} else {
			fp = fopen(fips_fpath, "r");
			if (!fp)
				test_error("Can't open %s", fips_fpath);
			if (fscanf(fp, "%d", &value) != 1)
				test_error("Can't read from %s", fips_fpath);
			fclose(fp);
			cached = !!value;
		}
	}

	return !!cached;
}
407 
408 struct test_key {
409 	char password[TCP_AO_MAXKEYLEN];
410 	const char *alg;
411 	unsigned int len;
412 	uint8_t client_keyid;
413 	uint8_t server_keyid;
414 	uint8_t maclen;
415 	uint8_t matches_client		: 1,
416 		matches_server		: 1,
417 		matches_vrf		: 1,
418 		is_current		: 1,
419 		is_rnext		: 1,
420 		used_on_server_tx	: 1,
421 		used_on_client_tx	: 1,
422 		skip_counters_checks	: 1;
423 };
424 
/* Heap-allocated array of test keys plus its element count. */
struct key_collection {
	unsigned int nr_keys;
	struct test_key *keys;
};

/* The set of keys used by the current test; see init_default_key_collection(). */
static struct key_collection collection;
431 
/* Upper bound for the MAC length of randomized keys. */
#define TEST_MAX_MACLEN		16
/* Number of trailing test_algos[] entries left out when FIPS is enabled. */
#define TEST_NON_FIPS_ALGOS	2
/* Bit widths used to derive maclen/algorithm from a key index. */
#define MACLEN_SHIFT		2
#define ALGOS_SHIFT		4

const char *test_algos[] = {
	"cmac(aes128)",
	"hmac(sha1)",
	"hmac(sha512)",
	"hmac(sha384)",
	"hmac(sha256)",
	"hmac(sha224)",
	"hmac(sha3-512)",
	/* only if !CONFIG_FIPS */
	"hmac(rmd160)",
	"hmac(md5)",
};
const unsigned int test_maclens[] = { 1, 4, 12, 16 };
444 
/* Build a mask of @shift consecutive set bits, starting at bit @prev_shift. */
static unsigned int make_mask(unsigned int shift, unsigned int prev_shift)
{
	unsigned int low_bits;

	low_bits = BIT(shift) - 1;
	low_bits <<= prev_shift;

	return low_bits;
}
451 
452 static void init_key_in_collection(unsigned int index, bool randomized)
453 {
454 	struct test_key *key = &collection.keys[index];
455 	unsigned int algos_nr, algos_index;
456 
457 	/* Same for randomized and non-randomized test flows */
458 	key->client_keyid = index;
459 	key->server_keyid = 127 + index;
460 	key->matches_client = 1;
461 	key->matches_server = 1;
462 	key->matches_vrf = 1;
463 	/* not really even random, but good enough for a test */
464 	key->len = rand() % (TCP_AO_MAXKEYLEN - TEST_TCP_AO_MINKEYLEN);
465 	key->len += TEST_TCP_AO_MINKEYLEN;
466 	randomize_buffer(key->password, key->len);
467 
468 	if (randomized) {
469 		key->maclen = (rand() % TEST_MAX_MACLEN) + 1;
470 		algos_index = rand();
471 	} else {
472 		unsigned int shift = MACLEN_SHIFT;
473 
474 		key->maclen = test_maclens[index & make_mask(shift, 0)];
475 		algos_index = index & make_mask(ALGOS_SHIFT, shift);
476 	}
477 	algos_nr = ARRAY_SIZE(test_algos);
478 	if (is_fips_enabled())
479 		algos_nr -= TEST_NON_FIPS_ALGOS;
480 	key->alg = test_algos[algos_index % algos_nr];
481 }
482 
483 static int init_default_key_collection(unsigned int nr_keys, bool randomized)
484 {
485 	size_t key_sz = sizeof(collection.keys[0]);
486 
487 	if (!nr_keys) {
488 		free(collection.keys);
489 		collection.keys = NULL;
490 		return 0;
491 	}
492 
493 	/*
494 	 * All keys have uniq sndid/rcvid and sndid != rcvid in order to
495 	 * check for any bugs/issues for different keyids, visible to both
496 	 * peers. Keyid == 254 is unused.
497 	 */
498 	if (nr_keys > 127)
499 		test_error("Test requires too many keys, correct the source");
500 
501 	collection.keys = reallocarray(collection.keys, nr_keys, key_sz);
502 	if (!collection.keys)
503 		return -ENOMEM;
504 
505 	memset(collection.keys, 0, nr_keys * key_sz);
506 	collection.nr_keys = nr_keys;
507 	while (nr_keys--)
508 		init_key_in_collection(nr_keys, randomized);
509 
510 	return 0;
511 }
512 
513 static void test_key_error(const char *msg, struct test_key *key)
514 {
515 	test_error("%s: key: { %s, %u:%u, %u, %u:%u:%u:%u:%u (%u)}",
516 		   msg, key->alg, key->client_keyid, key->server_keyid,
517 		   key->maclen, key->matches_client, key->matches_server,
518 		   key->matches_vrf, key->is_current, key->is_rnext, key->len);
519 }
520 
521 static int test_add_key_cr(int sk, const char *pwd, unsigned int pwd_len,
522 			   union tcp_addr addr, uint8_t vrf,
523 			   uint8_t sndid, uint8_t rcvid,
524 			   uint8_t maclen, const char *alg,
525 			   bool set_current, bool set_rnext)
526 {
527 	struct tcp_ao_add tmp = {};
528 	uint8_t keyflags = 0;
529 	int err;
530 
531 	if (!alg)
532 		alg = DEFAULT_TEST_ALGO;
533 
534 	if (vrf)
535 		keyflags |= TCP_AO_KEYF_IFINDEX;
536 	err = test_prepare_key(&tmp, alg, addr, set_current, set_rnext,
537 			       DEFAULT_TEST_PREFIX, vrf, sndid, rcvid, maclen,
538 			       keyflags, pwd_len, pwd);
539 	if (err)
540 		return err;
541 
542 	err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp));
543 	if (err < 0)
544 		return -errno;
545 
546 	return test_verify_socket_key(sk, &tmp);
547 }
548 
549 static void verify_current_rnext(const char *tst, int sk,
550 				 int current_keyid, int rnext_keyid)
551 {
552 	struct tcp_ao_info_opt ao_info = {};
553 
554 	if (test_get_ao_info(sk, &ao_info))
555 		test_error("getsockopt(TCP_AO_INFO) failed");
556 
557 	errno = 0;
558 	if (current_keyid >= 0) {
559 		if (!ao_info.set_current)
560 			test_fail("%s: the socket doesn't have current key", tst);
561 		else if (ao_info.current_key != current_keyid)
562 			test_fail("%s: current key is not the expected one %d != %u",
563 				  tst, current_keyid, ao_info.current_key);
564 		else
565 			test_ok("%s: current key %u as expected",
566 				tst, ao_info.current_key);
567 	}
568 	if (rnext_keyid >= 0) {
569 		if (!ao_info.set_rnext)
570 			test_fail("%s: the socket doesn't have rnext key", tst);
571 		else if (ao_info.rnext != rnext_keyid)
572 			test_fail("%s: rnext key is not the expected one %d != %u",
573 				  tst, rnext_keyid, ao_info.rnext);
574 		else
575 			test_ok("%s: rnext key %u as expected", tst, ao_info.rnext);
576 	}
577 }
578 
579 
580 static int key_collection_socket(bool server, unsigned int port)
581 {
582 	unsigned int i;
583 	int sk;
584 
585 	if (server)
586 		sk = test_listen_socket(this_ip_addr, port, 1);
587 	else
588 		sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
589 	if (sk < 0)
590 		test_error("socket()");
591 
592 	for (i = 0; i < collection.nr_keys; i++) {
593 		struct test_key *key = &collection.keys[i];
594 		union tcp_addr *addr = &wrong_addr;
595 		uint8_t sndid, rcvid, vrf;
596 		bool set_current = false, set_rnext = false;
597 
598 		if (key->matches_vrf)
599 			vrf = 0;
600 		else
601 			vrf = test_vrf_ifindex;
602 		if (server) {
603 			if (key->matches_client)
604 				addr = &this_ip_dest;
605 			sndid = key->server_keyid;
606 			rcvid = key->client_keyid;
607 		} else {
608 			if (key->matches_server)
609 				addr = &this_ip_dest;
610 			sndid = key->client_keyid;
611 			rcvid = key->server_keyid;
612 			key->used_on_client_tx = set_current = key->is_current;
613 			key->used_on_server_tx = set_rnext = key->is_rnext;
614 		}
615 
616 		if (test_add_key_cr(sk, key->password, key->len,
617 				    *addr, vrf, sndid, rcvid, key->maclen,
618 				    key->alg, set_current, set_rnext))
619 			test_key_error("setsockopt(TCP_AO_ADD_KEY)", key);
620 #ifdef DEBUG
621 		test_print("%s [%u/%u] key: { %s, %u:%u, %u, %u:%u:%u:%u (%u)}",
622 			   server ? "server" : "client", i, collection.nr_keys,
623 			   key->alg, rcvid, sndid, key->maclen,
624 			   key->matches_client, key->matches_server,
625 			   key->is_current, key->is_rnext, key->len);
626 #endif
627 	}
628 	return sk;
629 }
630 
631 static void verify_counters(const char *tst_name, bool is_listen_sk, bool server,
632 			    struct tcp_ao_counters *a, struct tcp_ao_counters *b)
633 {
634 	unsigned int i;
635 
636 	__test_tcp_ao_counters_cmp(tst_name, a, b, TEST_CNT_GOOD);
637 
638 	for (i = 0; i < collection.nr_keys; i++) {
639 		struct test_key *key = &collection.keys[i];
640 		uint8_t sndid, rcvid;
641 		bool rx_cnt_expected;
642 
643 		if (key->skip_counters_checks)
644 			continue;
645 		if (server) {
646 			sndid = key->server_keyid;
647 			rcvid = key->client_keyid;
648 			rx_cnt_expected = key->used_on_client_tx;
649 		} else {
650 			sndid = key->client_keyid;
651 			rcvid = key->server_keyid;
652 			rx_cnt_expected = key->used_on_server_tx;
653 		}
654 
655 		test_tcp_ao_key_counters_cmp(tst_name, a, b,
656 					     rx_cnt_expected ? TEST_CNT_KEY_GOOD : 0,
657 					     sndid, rcvid);
658 	}
659 	test_tcp_ao_counters_free(a);
660 	test_tcp_ao_counters_free(b);
661 	test_ok("%s: passed counters checks", tst_name);
662 }
663 
664 static struct tcp_ao_getsockopt *lookup_key(struct tcp_ao_getsockopt *buf,
665 					    size_t len, int sndid, int rcvid)
666 {
667 	size_t i;
668 
669 	for (i = 0; i < len; i++) {
670 		if (sndid >= 0 && buf[i].sndid != sndid)
671 			continue;
672 		if (rcvid >= 0 && buf[i].rcvid != rcvid)
673 			continue;
674 		return &buf[i];
675 	}
676 	return NULL;
677 }
678 
679 static void verify_keys(const char *tst_name, int sk,
680 			bool is_listen_sk, bool server)
681 {
682 	socklen_t len = sizeof(struct tcp_ao_getsockopt);
683 	struct tcp_ao_getsockopt *keys;
684 	bool passed_test = true;
685 	unsigned int i;
686 
687 	keys = calloc(collection.nr_keys, len);
688 	if (!keys)
689 		test_error("calloc()");
690 
691 	keys->nkeys = collection.nr_keys;
692 	keys->get_all = 1;
693 
694 	if (getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, keys, &len)) {
695 		free(keys);
696 		test_error("getsockopt(TCP_AO_GET_KEYS)");
697 	}
698 
699 	for (i = 0; i < collection.nr_keys; i++) {
700 		struct test_key *key = &collection.keys[i];
701 		struct tcp_ao_getsockopt *dump_key;
702 		bool is_kdf_aes_128_cmac = false;
703 		bool is_cmac_aes = false;
704 		uint8_t sndid, rcvid;
705 		bool matches = false;
706 
707 		if (server) {
708 			if (key->matches_client)
709 				matches = true;
710 			sndid = key->server_keyid;
711 			rcvid = key->client_keyid;
712 		} else {
713 			if (key->matches_server)
714 				matches = true;
715 			sndid = key->client_keyid;
716 			rcvid = key->server_keyid;
717 		}
718 		if (!key->matches_vrf)
719 			matches = false;
720 		/* no keys get removed on the original listener socket */
721 		if (is_listen_sk)
722 			matches = true;
723 
724 		dump_key = lookup_key(keys, keys->nkeys, sndid, rcvid);
725 		if (matches != !!dump_key) {
726 			test_fail("%s: key %u:%u %s%s on the socket",
727 				  tst_name, sndid, rcvid,
728 				  key->matches_vrf ? "" : "[vrf] ",
729 				  matches ? "disappeared" : "yet present");
730 			passed_test = false;
731 			goto out;
732 		}
733 		if (!dump_key)
734 			continue;
735 
736 		if (!strcmp("cmac(aes128)", key->alg)) {
737 			is_kdf_aes_128_cmac = (key->len != 16);
738 			is_cmac_aes = true;
739 		}
740 
741 		if (is_cmac_aes) {
742 			if (strcmp(dump_key->alg_name, "cmac(aes)")) {
743 				test_fail("%s: key %u:%u cmac(aes) has unexpected alg %s",
744 					  tst_name, sndid, rcvid,
745 					  dump_key->alg_name);
746 				passed_test = false;
747 				continue;
748 			}
749 		} else if (strcmp(dump_key->alg_name, key->alg)) {
750 			test_fail("%s: key %u:%u has unexpected alg %s != %s",
751 				  tst_name, sndid, rcvid,
752 				  dump_key->alg_name, key->alg);
753 			passed_test = false;
754 			continue;
755 		}
756 		if (is_kdf_aes_128_cmac) {
757 			if (dump_key->keylen != 16) {
758 				test_fail("%s: key %u:%u cmac(aes128) has unexpected len %u",
759 					  tst_name, sndid, rcvid,
760 					  dump_key->keylen);
761 				continue;
762 			}
763 		} else if (dump_key->keylen != key->len) {
764 			test_fail("%s: key %u:%u changed password len %u != %u",
765 				  tst_name, sndid, rcvid,
766 				  dump_key->keylen, key->len);
767 			passed_test = false;
768 			continue;
769 		}
770 		if (!is_kdf_aes_128_cmac &&
771 		    memcmp(dump_key->key, key->password, key->len)) {
772 			test_fail("%s: key %u:%u has different password",
773 				  tst_name, sndid, rcvid);
774 			passed_test = false;
775 			continue;
776 		}
777 		if (dump_key->maclen != key->maclen) {
778 			test_fail("%s: key %u:%u changed maclen %u != %u",
779 				  tst_name, sndid, rcvid,
780 				  dump_key->maclen, key->maclen);
781 			passed_test = false;
782 			continue;
783 		}
784 	}
785 
786 	if (passed_test)
787 		test_ok("%s: The socket keys are consistent with the expectations",
788 			tst_name);
789 out:
790 	free(keys);
791 }
792 
793 static int start_server(const char *tst_name, unsigned int port, size_t quota,
794 			struct tcp_ao_counters *begin,
795 			unsigned int current_index, unsigned int rnext_index)
796 {
797 	struct tcp_ao_counters lsk_c1, lsk_c2;
798 	ssize_t bytes;
799 	int sk, lsk;
800 
801 	synchronize_threads(); /* 1: key collection initialized */
802 	lsk = key_collection_socket(true, port);
803 	if (test_get_tcp_ao_counters(lsk, &lsk_c1))
804 		test_error("test_get_tcp_ao_counters()");
805 	synchronize_threads(); /* 2: MKTs added => connect() */
806 	if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
807 		test_error("test_wait_fd()");
808 
809 	sk = accept(lsk, NULL, NULL);
810 	if (sk < 0)
811 		test_error("accept()");
812 	if (test_get_tcp_ao_counters(sk, begin))
813 		test_error("test_get_tcp_ao_counters()");
814 
815 	synchronize_threads(); /* 3: accepted => send data */
816 	if (test_get_tcp_ao_counters(lsk, &lsk_c2))
817 		test_error("test_get_tcp_ao_counters()");
818 	verify_keys(tst_name, lsk, true, true);
819 	close(lsk);
820 
821 	bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC);
822 	if (bytes != quota)
823 		test_fail("%s: server served: %zd", tst_name, bytes);
824 	else
825 		test_ok("%s: server alive", tst_name);
826 
827 	verify_counters(tst_name, true, true, &lsk_c1, &lsk_c2);
828 
829 	return sk;
830 }
831 
832 static void end_server(const char *tst_name, int sk,
833 		       struct tcp_ao_counters *begin)
834 {
835 	struct tcp_ao_counters end;
836 
837 	if (test_get_tcp_ao_counters(sk, &end))
838 		test_error("test_get_tcp_ao_counters()");
839 	verify_keys(tst_name, sk, false, true);
840 
841 	synchronize_threads(); /* 4: verified => closed */
842 	close(sk);
843 
844 	verify_counters(tst_name, false, true, begin, &end);
845 	synchronize_threads(); /* 5: counters */
846 }
847 
848 static void try_server_run(const char *tst_name, unsigned int port, size_t quota,
849 			   unsigned int current_index, unsigned int rnext_index)
850 {
851 	struct tcp_ao_counters tmp;
852 	int sk;
853 
854 	sk = start_server(tst_name, port, quota, &tmp,
855 			  current_index, rnext_index);
856 	end_server(tst_name, sk, &tmp);
857 }
858 
859 static void server_rotations(const char *tst_name, unsigned int port,
860 			     size_t quota, unsigned int rotations,
861 			     unsigned int current_index, unsigned int rnext_index)
862 {
863 	struct tcp_ao_counters tmp;
864 	unsigned int i;
865 	int sk;
866 
867 	sk = start_server(tst_name, port, quota, &tmp,
868 			  current_index, rnext_index);
869 
870 	for (i = current_index + 1; rotations > 0; i++, rotations--) {
871 		ssize_t bytes;
872 
873 		if (i >= collection.nr_keys)
874 			i = 0;
875 		bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC);
876 		if (bytes != quota) {
877 			test_fail("%s: server served: %zd", tst_name, bytes);
878 			return;
879 		}
880 		verify_current_rnext(tst_name, sk,
881 				     collection.keys[i].server_keyid, -1);
882 		synchronize_threads(); /* verify current/rnext */
883 	}
884 	end_server(tst_name, sk, &tmp);
885 }
886 
887 static int run_client(const char *tst_name, unsigned int port,
888 		      unsigned int nr_keys, int current_index, int rnext_index,
889 		      struct tcp_ao_counters *before,
890 		      const size_t msg_sz, const size_t msg_nr)
891 {
892 	int sk;
893 
894 	synchronize_threads(); /* 1: key collection initialized */
895 	sk = key_collection_socket(false, port);
896 
897 	if (current_index >= 0 || rnext_index >= 0) {
898 		int sndid = -1, rcvid = -1;
899 
900 		if (current_index >= 0)
901 			sndid = collection.keys[current_index].client_keyid;
902 		if (rnext_index >= 0)
903 			rcvid = collection.keys[rnext_index].server_keyid;
904 		if (test_set_key(sk, sndid, rcvid))
905 			test_error("failed to set current/rnext keys");
906 	}
907 	if (before && test_get_tcp_ao_counters(sk, before))
908 		test_error("test_get_tcp_ao_counters()");
909 
910 	synchronize_threads(); /* 2: MKTs added => connect() */
911 	if (test_connect_socket(sk, this_ip_dest, port++) <= 0)
912 		test_error("failed to connect()");
913 	if (current_index < 0)
914 		current_index = nr_keys - 1;
915 	if (rnext_index < 0)
916 		rnext_index = nr_keys - 1;
917 	collection.keys[current_index].used_on_client_tx = 1;
918 	collection.keys[rnext_index].used_on_server_tx = 1;
919 
920 	synchronize_threads(); /* 3: accepted => send data */
921 	if (test_client_verify(sk, msg_sz, msg_nr, TEST_TIMEOUT_SEC)) {
922 		test_fail("verify failed");
923 		close(sk);
924 		if (before)
925 			test_tcp_ao_counters_free(before);
926 		return -1;
927 	}
928 
929 	return sk;
930 }
931 
/*
 * Initialize a randomized key collection of @nr_keys keys and run the
 * client side with it; see run_client() for the arguments and return value.
 */
static int start_client(const char *tst_name, unsigned int port,
			unsigned int nr_keys, int current_index, int rnext_index,
			struct tcp_ao_counters *before,
			const size_t msg_sz, const size_t msg_nr)
{
	int ret = init_default_key_collection(nr_keys, true);

	if (ret)
		test_error("Failed to init the key collection");

	return run_client(tst_name, port, nr_keys, current_index,
			  rnext_index, before, msg_sz, msg_nr);
}
943 
944 static void end_client(const char *tst_name, int sk, unsigned int nr_keys,
945 		       int current_index, int rnext_index,
946 		       struct tcp_ao_counters *start)
947 {
948 	struct tcp_ao_counters end;
949 
950 	/* Some application may become dependent on this kernel choice */
951 	if (current_index < 0)
952 		current_index = nr_keys - 1;
953 	if (rnext_index < 0)
954 		rnext_index = nr_keys - 1;
955 	verify_current_rnext(tst_name, sk,
956 			     collection.keys[current_index].client_keyid,
957 			     collection.keys[rnext_index].server_keyid);
958 	if (start && test_get_tcp_ao_counters(sk, &end))
959 		test_error("test_get_tcp_ao_counters()");
960 	verify_keys(tst_name, sk, false, false);
961 	synchronize_threads(); /* 4: verify => closed */
962 	close(sk);
963 	if (start)
964 		verify_counters(tst_name, false, false, start, &end);
965 	synchronize_threads(); /* 5: counters */
966 }
967 
968 static void try_unmatched_keys(int sk, int *rnext_index)
969 {
970 	struct test_key *key;
971 	unsigned int i = 0;
972 	int err;
973 
974 	do {
975 		key = &collection.keys[i];
976 		if (!key->matches_server)
977 			break;
978 	} while (++i < collection.nr_keys);
979 	if (key->matches_server)
980 		test_error("all keys on client match the server");
981 
982 	err = test_add_key_cr(sk, key->password, key->len, wrong_addr,
983 			      0, key->client_keyid, key->server_keyid,
984 			      key->maclen, key->alg, 0, 0);
985 	if (!err) {
986 		test_fail("Added a key with non-matching ip-address for established sk");
987 		return;
988 	}
989 	if (err == -EINVAL)
990 		test_ok("Can't add a key with non-matching ip-address for established sk");
991 	else
992 		test_error("Failed to add a key");
993 
994 	err = test_add_key_cr(sk, key->password, key->len, this_ip_dest,
995 			      test_vrf_ifindex,
996 			      key->client_keyid, key->server_keyid,
997 			      key->maclen, key->alg, 0, 0);
998 	if (!err) {
999 		test_fail("Added a key with non-matching VRF for established sk");
1000 		return;
1001 	}
1002 	if (err == -EINVAL)
1003 		test_ok("Can't add a key with non-matching VRF for established sk");
1004 	else
1005 		test_error("Failed to add a key");
1006 
1007 	for (i = 0; i < collection.nr_keys; i++) {
1008 		key = &collection.keys[i];
1009 		if (!key->matches_client)
1010 			break;
1011 	}
1012 	if (key->matches_client)
1013 		test_error("all keys on server match the client");
1014 	if (test_set_key(sk, -1, key->server_keyid))
1015 		test_error("Can't change the current key");
1016 	if (test_client_verify(sk, msg_len, nr_packets, TEST_TIMEOUT_SEC))
1017 		test_fail("verify failed");
1018 	*rnext_index = i;
1019 }
1020 
1021 static int client_non_matching(const char *tst_name, unsigned int port,
1022 			       unsigned int nr_keys,
1023 			       int current_index, int rnext_index,
1024 			       const size_t msg_sz, const size_t msg_nr)
1025 {
1026 	unsigned int i;
1027 
1028 	if (init_default_key_collection(nr_keys, true))
1029 		test_error("Failed to init the key collection");
1030 
1031 	for (i = 0; i < nr_keys; i++) {
1032 		/* key (0, 0) matches */
1033 		collection.keys[i].matches_client = !!((i + 3) % 4);
1034 		collection.keys[i].matches_server = !!((i + 2) % 4);
1035 		if (kernel_config_has(KCONFIG_NET_VRF))
1036 			collection.keys[i].matches_vrf = !!((i + 1) % 4);
1037 	}
1038 
1039 	return run_client(tst_name, port, nr_keys, current_index,
1040 			  rnext_index, NULL, msg_sz, msg_nr);
1041 }
1042 
1043 static void check_current_back(const char *tst_name, unsigned int port,
1044 			       unsigned int nr_keys,
1045 			       unsigned int current_index, unsigned int rnext_index,
1046 			       unsigned int rotate_to_index)
1047 {
1048 	struct tcp_ao_counters tmp;
1049 	int sk;
1050 
1051 	sk = start_client(tst_name, port, nr_keys, current_index, rnext_index,
1052 			  &tmp, msg_len, nr_packets);
1053 	if (sk < 0)
1054 		return;
1055 	if (test_set_key(sk, collection.keys[rotate_to_index].client_keyid, -1))
1056 		test_error("Can't change the current key");
1057 	if (test_client_verify(sk, msg_len, nr_packets, TEST_TIMEOUT_SEC))
1058 		test_fail("verify failed");
1059 	/* There is a race here: between setting the current_key with
1060 	 * setsockopt(TCP_AO_INFO) and starting to send some data - there
1061 	 * might have been a segment received with the desired
1062 	 * RNext_key set. In turn that would mean that the first outgoing
1063 	 * segment will have the desired current_key (flipped back).
1064 	 * Which is what the user/test wants. As it's racy, skip checking
1065 	 * the counters, yet check what are the resulting current/rnext
1066 	 * keys on both sides.
1067 	 */
1068 	collection.keys[rotate_to_index].skip_counters_checks = 1;
1069 
1070 	end_client(tst_name, sk, nr_keys, current_index, rnext_index, &tmp);
1071 }
1072 
1073 static void roll_over_keys(const char *tst_name, unsigned int port,
1074 			   unsigned int nr_keys, unsigned int rotations,
1075 			   unsigned int current_index, unsigned int rnext_index)
1076 {
1077 	struct tcp_ao_counters tmp;
1078 	unsigned int i;
1079 	int sk;
1080 
1081 	sk = start_client(tst_name, port, nr_keys, current_index, rnext_index,
1082 			  &tmp, msg_len, nr_packets);
1083 	if (sk < 0)
1084 		return;
1085 	for (i = rnext_index + 1; rotations > 0; i++, rotations--) {
1086 		if (i >= collection.nr_keys)
1087 			i = 0;
1088 		if (test_set_key(sk, -1, collection.keys[i].server_keyid))
1089 			test_error("Can't change the Rnext key");
1090 		if (test_client_verify(sk, msg_len, nr_packets, TEST_TIMEOUT_SEC)) {
1091 			test_fail("verify failed");
1092 			close(sk);
1093 			test_tcp_ao_counters_free(&tmp);
1094 			return;
1095 		}
1096 		verify_current_rnext(tst_name, sk, -1,
1097 				     collection.keys[i].server_keyid);
1098 		collection.keys[i].used_on_server_tx = 1;
1099 		synchronize_threads(); /* verify current/rnext */
1100 	}
1101 	end_client(tst_name, sk, nr_keys, current_index, rnext_index, &tmp);
1102 }
1103 
1104 static void try_client_run(const char *tst_name, unsigned int port,
1105 			   unsigned int nr_keys, int current_index, int rnext_index)
1106 {
1107 	struct tcp_ao_counters tmp;
1108 	int sk;
1109 
1110 	sk = start_client(tst_name, port, nr_keys, current_index, rnext_index,
1111 			  &tmp, msg_len, nr_packets);
1112 	if (sk < 0)
1113 		return;
1114 	end_client(tst_name, sk, nr_keys, current_index, rnext_index, &tmp);
1115 }
1116 
1117 static void try_client_match(const char *tst_name, unsigned int port,
1118 			     unsigned int nr_keys,
1119 			     int current_index, int rnext_index)
1120 {
1121 	int sk;
1122 
1123 	sk = client_non_matching(tst_name, port, nr_keys, current_index,
1124 				 rnext_index, msg_len, nr_packets);
1125 	if (sk < 0)
1126 		return;
1127 	try_unmatched_keys(sk, &rnext_index);
1128 	end_client(tst_name, sk, nr_keys, current_index, rnext_index, NULL);
1129 }
1130 
1131 static void *server_fn(void *arg)
1132 {
1133 	unsigned int port = test_server_port;
1134 
1135 	setup_vrfs();
1136 	try_server_run("server: Check current/rnext keys unset before connect()",
1137 		       port++, quota, 19, 19);
1138 	try_server_run("server: Check current/rnext keys set before connect()",
1139 		       port++, quota, 10, 10);
1140 	try_server_run("server: Check current != rnext keys set before connect()",
1141 		       port++, quota, 5, 10);
1142 	try_server_run("server: Check current flapping back on peer's RnextKey request",
1143 		       port++, quota * 2, 5, 10);
1144 	server_rotations("server: Rotate over all different keys", port++,
1145 			 quota, 20, 0, 0);
1146 	try_server_run("server: Check accept() => established key matching",
1147 		       port++, quota * 2, 0, 0);
1148 
1149 	synchronize_threads(); /* don't race to exit: client exits */
1150 	return NULL;
1151 }
1152 
1153 static void check_established_socket(void)
1154 {
1155 	unsigned int port = test_server_port;
1156 
1157 	setup_vrfs();
1158 	try_client_run("client: Check current/rnext keys unset before connect()",
1159 		       port++, 20, -1, -1);
1160 	try_client_run("client: Check current/rnext keys set before connect()",
1161 		       port++, 20, 10, 10);
1162 	try_client_run("client: Check current != rnext keys set before connect()",
1163 		       port++, 20, 10, 5);
1164 	check_current_back("client: Check current flapping back on peer's RnextKey request",
1165 			   port++, 20, 10, 5, 2);
1166 	roll_over_keys("client: Rotate over all different keys", port++,
1167 		       20, 20, 0, 0);
1168 	try_client_match("client: Check connect() => established key matching",
1169 			 port++, 20, 0, 0);
1170 }
1171 
1172 static void *client_fn(void *arg)
1173 {
1174 	if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1)
1175 		test_error("Can't convert ip address %s", TEST_WRONG_IP);
1176 	check_closed_socket();
1177 	check_listen_socket();
1178 	check_established_socket();
1179 	return NULL;
1180 }
1181 
1182 int main(int argc, char *argv[])
1183 {
1184 	test_init(120, server_fn, client_fn);
1185 	return 0;
1186 }
1187