xref: /linux/tools/testing/selftests/net/tcp_ao/lib/sock.c (revision 6d9b262afe0ec1d6e0ef99321ca9d6b921310471)
1 // SPDX-License-Identifier: GPL-2.0
2 #include <alloca.h>
3 #include <fcntl.h>
4 #include <inttypes.h>
5 #include <string.h>
6 #include "../../../../../include/linux/kernel.h"
7 #include "../../../../../include/linux/stringify.h"
8 #include "aolib.h"
9 
10 const unsigned int test_server_port = 7010;
11 int __test_listen_socket(int backlog, void *addr, size_t addr_sz)
12 {
13 	int err, sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
14 	long flags;
15 
16 	if (sk < 0)
17 		test_error("socket()");
18 
19 	err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, veth_name,
20 			 strlen(veth_name) + 1);
21 	if (err < 0)
22 		test_error("setsockopt(SO_BINDTODEVICE)");
23 
24 	if (bind(sk, (struct sockaddr *)addr, addr_sz) < 0)
25 		test_error("bind()");
26 
27 	flags = fcntl(sk, F_GETFL);
28 	if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0))
29 		test_error("fcntl()");
30 
31 	if (listen(sk, backlog))
32 		test_error("listen()");
33 
34 	return sk;
35 }
36 
/*
 * test_wait_fd() - wait until @sk is readable or writable.
 * @sk:		descriptor to wait on
 * @sec:	timeout in seconds; 0 waits without a time limit
 * @write:	true waits for writability, false for readability
 *
 * An exceptional condition on @sk also ends the wait. Once ready, the pending
 * socket error (SO_ERROR) is fetched and reported.
 *
 * Return: 0 when ready with no pending error; -ETIMEDOUT (errno is also set
 * to ETIMEDOUT) when @sec elapsed; a negative errno if select() or
 * getsockopt() failed; otherwise the negated pending socket error.
 */
int test_wait_fd(int sk, time_t sec, bool write)
{
	struct timeval limit = { .tv_sec = sec };
	fd_set ready_set, except_set;
	fd_set *rd_set = NULL, *wr_set = NULL;
	int nready, sk_err;
	socklen_t optlen = sizeof(sk_err);

	FD_ZERO(&ready_set);
	FD_ZERO(&except_set);
	FD_SET(sk, &ready_set);
	FD_SET(sk, &except_set);

	if (write)
		wr_set = &ready_set;
	else
		rd_set = &ready_set;

	errno = 0;
	nready = select(sk + 1, rd_set, wr_set, &except_set,
			sec ? &limit : NULL);
	if (nready < 0)
		return -errno;
	if (nready == 0) {
		errno = ETIMEDOUT;
		return -ETIMEDOUT;
	}

	if (getsockopt(sk, SOL_SOCKET, SO_ERROR, &sk_err, &optlen))
		return -errno;
	return sk_err ? -sk_err : 0;
}
71 
72 int __test_connect_socket(int sk, const char *device,
73 			  void *addr, size_t addr_sz, time_t timeout)
74 {
75 	long flags;
76 	int err;
77 
78 	if (device != NULL) {
79 		err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, device,
80 				 strlen(device) + 1);
81 		if (err < 0)
82 			test_error("setsockopt(SO_BINDTODEVICE, %s)", device);
83 	}
84 
85 	if (!timeout) {
86 		err = connect(sk, addr, addr_sz);
87 		if (err) {
88 			err = -errno;
89 			goto out;
90 		}
91 		return 0;
92 	}
93 
94 	flags = fcntl(sk, F_GETFL);
95 	if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0))
96 		test_error("fcntl()");
97 
98 	if (connect(sk, addr, addr_sz) < 0) {
99 		if (errno != EINPROGRESS) {
100 			err = -errno;
101 			goto out;
102 		}
103 		if (timeout < 0)
104 			return sk;
105 		err = test_wait_fd(sk, timeout, 1);
106 		if (err)
107 			goto out;
108 	}
109 	return sk;
110 
111 out:
112 	close(sk);
113 	return err;
114 }
115 
116 int __test_set_md5(int sk, void *addr, size_t addr_sz, uint8_t prefix,
117 		   int vrf, const char *password)
118 {
119 	size_t pwd_len = strlen(password);
120 	struct tcp_md5sig md5sig = {};
121 
122 	md5sig.tcpm_keylen = pwd_len;
123 	memcpy(md5sig.tcpm_key, password, pwd_len);
124 	md5sig.tcpm_flags = TCP_MD5SIG_FLAG_PREFIX;
125 	md5sig.tcpm_prefixlen = prefix;
126 	if (vrf >= 0) {
127 		md5sig.tcpm_flags |= TCP_MD5SIG_FLAG_IFINDEX;
128 		md5sig.tcpm_ifindex = (uint8_t)vrf;
129 	}
130 	memcpy(&md5sig.tcpm_addr, addr, addr_sz);
131 
132 	errno = 0;
133 	return setsockopt(sk, IPPROTO_TCP, TCP_MD5SIG_EXT,
134 			&md5sig, sizeof(md5sig));
135 }
136 
137 
138 int test_prepare_key_sockaddr(struct tcp_ao_add *ao, const char *alg,
139 		void *addr, size_t addr_sz, bool set_current, bool set_rnext,
140 		uint8_t prefix, uint8_t vrf, uint8_t sndid, uint8_t rcvid,
141 		uint8_t maclen, uint8_t keyflags,
142 		uint8_t keylen, const char *key)
143 {
144 	memset(ao, 0, sizeof(struct tcp_ao_add));
145 
146 	ao->set_current	= !!set_current;
147 	ao->set_rnext	= !!set_rnext;
148 	ao->prefix	= prefix;
149 	ao->sndid	= sndid;
150 	ao->rcvid	= rcvid;
151 	ao->maclen	= maclen;
152 	ao->keyflags	= keyflags;
153 	ao->keylen	= keylen;
154 	ao->ifindex	= vrf;
155 
156 	memcpy(&ao->addr, addr, addr_sz);
157 
158 	if (strlen(alg) > 64)
159 		return -ENOBUFS;
160 	strncpy(ao->alg_name, alg, 64);
161 
162 	memcpy(ao->key, key,
163 	       (keylen > TCP_AO_MAXKEYLEN) ? TCP_AO_MAXKEYLEN : keylen);
164 	return 0;
165 }
166 
167 static int test_get_ao_keys_nr(int sk)
168 {
169 	struct tcp_ao_getsockopt tmp = {};
170 	socklen_t tmp_sz = sizeof(tmp);
171 	int ret;
172 
173 	tmp.nkeys  = 1;
174 	tmp.get_all = 1;
175 
176 	ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz);
177 	if (ret)
178 		return -errno;
179 	return (int)tmp.nkeys;
180 }
181 
182 int test_get_one_ao(int sk, struct tcp_ao_getsockopt *out,
183 		void *addr, size_t addr_sz, uint8_t prefix,
184 		uint8_t sndid, uint8_t rcvid)
185 {
186 	struct tcp_ao_getsockopt tmp = {};
187 	socklen_t tmp_sz = sizeof(tmp);
188 	int ret;
189 
190 	memcpy(&tmp.addr, addr, addr_sz);
191 	tmp.prefix = prefix;
192 	tmp.sndid  = sndid;
193 	tmp.rcvid  = rcvid;
194 	tmp.nkeys  = 1;
195 
196 	ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz);
197 	if (ret)
198 		return ret;
199 	if (tmp.nkeys != 1)
200 		return -E2BIG;
201 	*out = tmp;
202 	return 0;
203 }
204 
205 int test_get_ao_info(int sk, struct tcp_ao_info_opt *out)
206 {
207 	socklen_t sz = sizeof(*out);
208 
209 	out->reserved = 0;
210 	out->reserved2 = 0;
211 	if (getsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, out, &sz))
212 		return -errno;
213 	if (sz != sizeof(*out))
214 		return -EMSGSIZE;
215 	return 0;
216 }
217 
218 int test_set_ao_info(int sk, struct tcp_ao_info_opt *in)
219 {
220 	socklen_t sz = sizeof(*in);
221 
222 	in->reserved = 0;
223 	in->reserved2 = 0;
224 	if (setsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, in, sz))
225 		return -errno;
226 	return 0;
227 }
228 
229 int test_cmp_getsockopt_setsockopt(const struct tcp_ao_add *a,
230 				   const struct tcp_ao_getsockopt *b)
231 {
232 	bool is_kdf_aes_128_cmac = false;
233 	bool is_cmac_aes = false;
234 
235 	if (!strcmp("cmac(aes128)", a->alg_name)) {
236 		is_kdf_aes_128_cmac = (a->keylen != 16);
237 		is_cmac_aes = true;
238 	}
239 
240 #define __cmp_ao(member)						\
241 do {									\
242 	if (b->member != a->member) {					\
243 		test_fail("getsockopt(): " __stringify(member) " %u != %u",	\
244 				b->member, a->member);			\
245 		return -1;						\
246 	}								\
247 } while(0)
248 	__cmp_ao(sndid);
249 	__cmp_ao(rcvid);
250 	__cmp_ao(prefix);
251 	__cmp_ao(keyflags);
252 	__cmp_ao(ifindex);
253 	if (a->maclen) {
254 		__cmp_ao(maclen);
255 	} else if (b->maclen != 12) {
256 		test_fail("getsockopt(): expected default maclen 12, but it's %u",
257 				b->maclen);
258 		return -1;
259 	}
260 	if (!is_kdf_aes_128_cmac) {
261 		__cmp_ao(keylen);
262 	} else if (b->keylen != 16) {
263 		test_fail("getsockopt(): expected keylen 16 for cmac(aes128), but it's %u",
264 				b->keylen);
265 		return -1;
266 	}
267 #undef __cmp_ao
268 	if (!is_kdf_aes_128_cmac && memcmp(b->key, a->key, a->keylen)) {
269 		test_fail("getsockopt(): returned key is different `%s' != `%s'",
270 				b->key, a->key);
271 		return -1;
272 	}
273 	if (memcmp(&b->addr, &a->addr, sizeof(b->addr))) {
274 		test_fail("getsockopt(): returned address is different");
275 		return -1;
276 	}
277 	if (!is_cmac_aes && strcmp(b->alg_name, a->alg_name)) {
278 		test_fail("getsockopt(): returned algorithm %s is different than %s", b->alg_name, a->alg_name);
279 		return -1;
280 	}
281 	if (is_cmac_aes && strcmp(b->alg_name, "cmac(aes)")) {
282 		test_fail("getsockopt(): returned algorithm %s is different than cmac(aes)", b->alg_name);
283 		return -1;
284 	}
285 	/* For a established key rotation test don't add a key with
286 	 * set_current = 1, as it's likely to change by peer's request;
287 	 * rather use setsockopt(TCP_AO_INFO)
288 	 */
289 	if (a->set_current != b->is_current) {
290 		test_fail("getsockopt(): returned key is not Current_key");
291 		return -1;
292 	}
293 	if (a->set_rnext != b->is_rnext) {
294 		test_fail("getsockopt(): returned key is not RNext_key");
295 		return -1;
296 	}
297 
298 	return 0;
299 }
300 
301 int test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt *a,
302 				      const struct tcp_ao_info_opt *b)
303 {
304 	/* No check for ::current_key, as it may change by the peer */
305 	if (a->ao_required != b->ao_required) {
306 		test_fail("getsockopt(): returned ao doesn't have ao_required");
307 		return -1;
308 	}
309 	if (a->accept_icmps != b->accept_icmps) {
310 		test_fail("getsockopt(): returned ao doesn't accept ICMPs");
311 		return -1;
312 	}
313 	if (a->set_rnext && a->rnext != b->rnext) {
314 		test_fail("getsockopt(): RNext KeyID has changed");
315 		return -1;
316 	}
317 #define __cmp_cnt(member)						\
318 do {									\
319 	if (b->member != a->member) {					\
320 		test_fail("getsockopt(): " __stringify(member) " %llu != %llu",	\
321 				b->member, a->member);			\
322 		return -1;						\
323 	}								\
324 } while(0)
325 	if (a->set_counters) {
326 		__cmp_cnt(pkt_good);
327 		__cmp_cnt(pkt_bad);
328 		__cmp_cnt(pkt_key_not_found);
329 		__cmp_cnt(pkt_ao_required);
330 		__cmp_cnt(pkt_dropped_icmp);
331 	}
332 #undef __cmp_cnt
333 	return 0;
334 }
335 
336 int test_get_tcp_ao_counters(int sk, struct tcp_ao_counters *out)
337 {
338 	struct tcp_ao_getsockopt *key_dump;
339 	socklen_t key_dump_sz = sizeof(*key_dump);
340 	struct tcp_ao_info_opt info = {};
341 	bool c1, c2, c3, c4, c5;
342 	struct netstat *ns;
343 	int err, nr_keys;
344 
345 	memset(out, 0, sizeof(*out));
346 
347 	/* per-netns */
348 	ns = netstat_read();
349 	out->netns_ao_good = netstat_get(ns, "TCPAOGood", &c1);
350 	out->netns_ao_bad = netstat_get(ns, "TCPAOBad", &c2);
351 	out->netns_ao_key_not_found = netstat_get(ns, "TCPAOKeyNotFound", &c3);
352 	out->netns_ao_required = netstat_get(ns, "TCPAORequired", &c4);
353 	out->netns_ao_dropped_icmp = netstat_get(ns, "TCPAODroppedIcmps", &c5);
354 	netstat_free(ns);
355 	if (c1 || c2 || c3 || c4 || c5)
356 		return -EOPNOTSUPP;
357 
358 	err = test_get_ao_info(sk, &info);
359 	if (err)
360 		return err;
361 
362 	/* per-socket */
363 	out->ao_info_pkt_good		= info.pkt_good;
364 	out->ao_info_pkt_bad		= info.pkt_bad;
365 	out->ao_info_pkt_key_not_found	= info.pkt_key_not_found;
366 	out->ao_info_pkt_ao_required	= info.pkt_ao_required;
367 	out->ao_info_pkt_dropped_icmp	= info.pkt_dropped_icmp;
368 
369 	/* per-key */
370 	nr_keys = test_get_ao_keys_nr(sk);
371 	if (nr_keys < 0)
372 		return nr_keys;
373 	if (nr_keys == 0)
374 		test_error("test_get_ao_keys_nr() == 0");
375 	out->nr_keys = (size_t)nr_keys;
376 	key_dump = calloc(nr_keys, key_dump_sz);
377 	if (!key_dump)
378 		return -errno;
379 
380 	key_dump[0].nkeys = nr_keys;
381 	key_dump[0].get_all = 1;
382 	key_dump[0].get_all = 1;
383 	err = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS,
384 			 key_dump, &key_dump_sz);
385 	if (err) {
386 		free(key_dump);
387 		return -errno;
388 	}
389 
390 	out->key_cnts = calloc(nr_keys, sizeof(out->key_cnts[0]));
391 	if (!out->key_cnts) {
392 		free(key_dump);
393 		return -errno;
394 	}
395 
396 	while (nr_keys--) {
397 		out->key_cnts[nr_keys].sndid = key_dump[nr_keys].sndid;
398 		out->key_cnts[nr_keys].rcvid = key_dump[nr_keys].rcvid;
399 		out->key_cnts[nr_keys].pkt_good = key_dump[nr_keys].pkt_good;
400 		out->key_cnts[nr_keys].pkt_bad = key_dump[nr_keys].pkt_bad;
401 	}
402 	free(key_dump);
403 
404 	return 0;
405 }
406 
407 int __test_tcp_ao_counters_cmp(const char *tst_name,
408 			       struct tcp_ao_counters *before,
409 			       struct tcp_ao_counters *after,
410 			       test_cnt expected)
411 {
412 #define __cmp_ao(cnt, expecting_inc)					\
413 do {									\
414 	if (before->cnt > after->cnt) {					\
415 		test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64, \
416 			  tst_name ?: "", before->cnt, after->cnt);		\
417 		return -1;						\
418 	}								\
419 	if ((before->cnt != after->cnt) != (expecting_inc)) {		\
420 		test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64, \
421 			  tst_name ?: "", (expecting_inc) ? "" : "not ",	\
422 			  before->cnt, after->cnt);			\
423 		return -1;						\
424 	}								\
425 } while(0)
426 
427 	errno = 0;
428 	/* per-netns */
429 	__cmp_ao(netns_ao_good, !!(expected & TEST_CNT_NS_GOOD));
430 	__cmp_ao(netns_ao_bad, !!(expected & TEST_CNT_NS_BAD));
431 	__cmp_ao(netns_ao_key_not_found,
432 		 !!(expected & TEST_CNT_NS_KEY_NOT_FOUND));
433 	__cmp_ao(netns_ao_required, !!(expected & TEST_CNT_NS_AO_REQUIRED));
434 	__cmp_ao(netns_ao_dropped_icmp,
435 		 !!(expected & TEST_CNT_NS_DROPPED_ICMP));
436 	/* per-socket */
437 	__cmp_ao(ao_info_pkt_good, !!(expected & TEST_CNT_SOCK_GOOD));
438 	__cmp_ao(ao_info_pkt_bad, !!(expected & TEST_CNT_SOCK_BAD));
439 	__cmp_ao(ao_info_pkt_key_not_found,
440 		 !!(expected & TEST_CNT_SOCK_KEY_NOT_FOUND));
441 	__cmp_ao(ao_info_pkt_ao_required, !!(expected & TEST_CNT_SOCK_AO_REQUIRED));
442 	__cmp_ao(ao_info_pkt_dropped_icmp,
443 		 !!(expected & TEST_CNT_SOCK_DROPPED_ICMP));
444 	return 0;
445 #undef __cmp_ao
446 }
447 
448 int test_tcp_ao_key_counters_cmp(const char *tst_name,
449 				 struct tcp_ao_counters *before,
450 				 struct tcp_ao_counters *after,
451 				 test_cnt expected,
452 				 int sndid, int rcvid)
453 {
454 	size_t i;
455 #define __cmp_ao(i, cnt, expecting_inc)					\
456 do {									\
457 	if (before->key_cnts[i].cnt > after->key_cnts[i].cnt) {		\
458 		test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64 " for key %u:%u", \
459 			  tst_name ?: "", before->key_cnts[i].cnt,	\
460 			  after->key_cnts[i].cnt,			\
461 			  before->key_cnts[i].sndid,			\
462 			  before->key_cnts[i].rcvid);			\
463 		return -1;						\
464 	}								\
465 	if ((before->key_cnts[i].cnt != after->key_cnts[i].cnt) != (expecting_inc)) {		\
466 		test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64 " for key %u:%u", \
467 			  tst_name ?: "", (expecting_inc) ? "" : "not ",\
468 			  before->key_cnts[i].cnt,			\
469 			  after->key_cnts[i].cnt,			\
470 			  before->key_cnts[i].sndid,			\
471 			  before->key_cnts[i].rcvid);			\
472 		return -1;						\
473 	}								\
474 } while(0)
475 
476 	if (before->nr_keys != after->nr_keys) {
477 		test_fail("%s: Keys changed on the socket %zu != %zu",
478 			  tst_name, before->nr_keys, after->nr_keys);
479 		return -1;
480 	}
481 
482 	/* per-key */
483 	i = before->nr_keys;
484 	while (i--) {
485 		if (sndid >= 0 && before->key_cnts[i].sndid != sndid)
486 			continue;
487 		if (rcvid >= 0 && before->key_cnts[i].rcvid != rcvid)
488 			continue;
489 		__cmp_ao(i, pkt_good, !!(expected & TEST_CNT_KEY_GOOD));
490 		__cmp_ao(i, pkt_bad, !!(expected & TEST_CNT_KEY_BAD));
491 	}
492 	return 0;
493 #undef __cmp_ao
494 }
495 
496 void test_tcp_ao_counters_free(struct tcp_ao_counters *cnts)
497 {
498 	free(cnts->key_cnts);
499 }
500 
#define TEST_BUF_SIZE 4096
/*
 * test_server_run() - echo everything received on @sk back to the peer.
 * @sk:		connected socket
 * @quota:	stop once at least this many bytes were echoed; 0 - no limit
 * @timeout_sec: timeout for each test_wait_fd() call (0 - no limit)
 *
 * Data is handled in chunks of up to TEST_BUF_SIZE bytes. The loop also ends
 * when recv() returns 0 or send() returns 0. A recv() error or a short
 * send() is reported through test_error().
 *
 * Return: number of bytes echoed, or the negative error from test_wait_fd().
 */
ssize_t test_server_run(int sk, ssize_t quota, time_t timeout_sec)
{
	ssize_t echoed = 0;

	for (;;) {
		char chunk[TEST_BUF_SIZE];
		ssize_t got, put;
		int err;

		err = test_wait_fd(sk, timeout_sec, 0);
		if (err)
			return err;

		got = recv(sk, chunk, sizeof(chunk), 0);
		if (got < 0)
			test_error("recv(): %zd", got);
		if (!got)
			break;

		err = test_wait_fd(sk, timeout_sec, 1);
		if (err)
			return err;

		put = send(sk, chunk, got, 0);
		if (!put)
			break;
		if (put != got)
			test_error("send()");
		echoed += got;

		if (quota && echoed >= quota)
			break;
	}

	return echoed;
}
536 
/*
 * test_client_loop() - send @buf in @msg_len-sized pieces and verify each
 * piece is echoed back unchanged.
 * @sk:		connected socket; TCP_NODELAY is enabled on it
 * @buf:	data to send
 * @buf_sz:	number of bytes in @buf
 * @msg_len:	size of each piece (the last piece may be shorter)
 * @timeout_sec: timeout for each test_wait_fd() call (0 - no limit)
 *
 * Return: number of bytes of @buf fully sent before the loop ended: all of
 * @buf on success, fewer if send() returned 0 or recv() returned <= 0. On
 * error returns the negative error from test_wait_fd(), -1 if an echoed piece
 * differs, or -EINVAL for @msg_len == 0 with non-empty @buf (no progress
 * would be possible).
 */
ssize_t test_client_loop(int sk, char *buf, size_t buf_sz,
			 const size_t msg_len, time_t timeout_sec)
{
	/* a zero-length VLA is undefined behaviour, so size it at least 1 */
	char msg[msg_len ? msg_len : 1];
	int nodelay = 1;
	size_t i;

	if (setsockopt(sk, IPPROTO_TCP, TCP_NODELAY, &nodelay, sizeof(nodelay)))
		test_error("setsockopt(TCP_NODELAY)");

	/* with msg_len == 0 the loop below would never advance */
	if (!msg_len)
		return buf_sz ? -EINVAL : 0;

	for (i = 0; i < buf_sz; i += min(msg_len, buf_sz - i)) {
		size_t sent, bytes = min(msg_len, buf_sz - i);
		int ret;

		ret = test_wait_fd(sk, timeout_sec, 1);
		if (ret)
			return ret;

		sent = send(sk, buf + i, bytes, 0);
		if (sent == 0)
			break;
		if (sent != bytes)
			test_error("send()");

		bytes = 0;
		do {
			ssize_t got;

			ret = test_wait_fd(sk, timeout_sec, 0);
			if (ret)
				return ret;

			got = recv(sk, msg + bytes, sizeof(msg) - bytes, 0);
			if (got <= 0)
				return i;
			bytes += got;
		} while (bytes < sent);
		if (bytes > sent)
			test_error("recv(): %zu > %zu", bytes, sent);
		if (memcmp(buf + i, msg, bytes) != 0) {
			test_fail("received message differs");
			return -1;
		}
	}
	return i;
}
583 
/*
 * test_client_verify() - run test_client_loop() over @nr random messages.
 * @sk:		connected socket to a peer that echoes data back
 * @msg_len:	size of each message
 * @nr:		number of messages
 * @timeout_sec: timeout for each wait inside test_client_loop()
 *
 * The payload buffer is heap-allocated: msg_len * nr is caller-controlled
 * and could overflow the stack with alloca().
 *
 * Return: 0 if the whole buffer was echoed back; -1 if fewer bytes were
 * exchanged; -EOVERFLOW if msg_len * nr does not fit in size_t; -ENOMEM if
 * allocation fails; otherwise the negative result of test_client_loop().
 */
int test_client_verify(int sk, const size_t msg_len, const size_t nr,
		       time_t timeout_sec)
{
	size_t buf_sz;
	ssize_t ret;
	char *buf;

	if (nr && msg_len > SIZE_MAX / nr)
		return -EOVERFLOW;
	buf_sz = msg_len * nr;

	/* malloc(0) may legitimately return NULL, so ask for at least 1 */
	buf = malloc(buf_sz ? buf_sz : 1);
	if (!buf)
		return -ENOMEM;

	randomize_buffer(buf, buf_sz);
	ret = test_client_loop(sk, buf, buf_sz, msg_len, timeout_sec);
	free(buf);

	if (ret < 0)
		return (int)ret;
	return (size_t)ret != buf_sz ? -1 : 0;
}
597