xref: /linux/tools/testing/selftests/net/tcp_ao/lib/sock.c (revision 1a9239bb4253f9076b5b4b2a1a4e8d7defd77a95)
1 // SPDX-License-Identifier: GPL-2.0
2 #include <alloca.h>
3 #include <fcntl.h>
4 #include <inttypes.h>
5 #include <string.h>
6 #include "../../../../../include/linux/kernel.h"
7 #include "../../../../../include/linux/stringify.h"
8 #include "aolib.h"
9 
/* Port number shared by the tests (presumably the one the test server listens on - confirm against callers) */
const unsigned int test_server_port = 7010;
__test_listen_socket(int backlog,void * addr,size_t addr_sz)11 int __test_listen_socket(int backlog, void *addr, size_t addr_sz)
12 {
13 	int err, sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
14 	long flags;
15 
16 	if (sk < 0)
17 		test_error("socket()");
18 
19 	err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, veth_name,
20 			 strlen(veth_name) + 1);
21 	if (err < 0)
22 		test_error("setsockopt(SO_BINDTODEVICE)");
23 
24 	if (bind(sk, (struct sockaddr *)addr, addr_sz) < 0)
25 		test_error("bind()");
26 
27 	flags = fcntl(sk, F_GETFL);
28 	if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0))
29 		test_error("fcntl()");
30 
31 	if (listen(sk, backlog))
32 		test_error("listen()");
33 
34 	return sk;
35 }
36 
/*
 * Wait with select() until @sk becomes writable (@write == true) or
 * readable (@write == false), then fetch the pending socket error.
 *
 * @tv: timeout passed to select(); NULL is passed through unchanged
 *
 * Returns 0 when the socket is ready and SO_ERROR is clear, -ETIMEDOUT
 * (with errno set to ETIMEDOUT) when select() reported no ready
 * descriptors, -errno when select()/getsockopt() failed, or the negated
 * SO_ERROR value otherwise.
 */
static int __test_wait_fd(int sk, struct timeval *tv, bool write)
{
	fd_set ready, except;
	socklen_t optlen;
	int val;

	FD_ZERO(&ready);
	FD_ZERO(&except);
	FD_SET(sk, &ready);
	FD_SET(sk, &except);

	errno = 0;
	/* The same set serves as the read or the write set, per @write */
	val = select(sk + 1, write ? NULL : &ready, write ? &ready : NULL,
		     &except, tv);
	if (val < 0)
		return -errno;
	if (val == 0) {
		errno = ETIMEDOUT;
		return -ETIMEDOUT;
	}

	optlen = sizeof(val);
	if (getsockopt(sk, SOL_SOCKET, SO_ERROR, &val, &optlen) != 0)
		return -errno;

	return val ? -val : 0;
}
66 
/*
 * Wait for @sk to become writable (@write) or readable (!@write).
 *
 * @sec: timeout in seconds; when zero a NULL timeout is handed to
 *       __test_wait_fd() instead of a zero-length one.
 *
 * Returns whatever __test_wait_fd() returns.
 */
int test_wait_fd(int sk, time_t sec, bool write)
{
	struct timeval tv = { .tv_sec = sec, };
	struct timeval *timeout = &tv;

	if (!sec)
		timeout = NULL;

	return __test_wait_fd(sk, timeout, write);
}
73 
/*
 * Take a fresh counters snapshot for @sk and compare it against @c.
 * Returns true only when every bit of @condition is present in the set of
 * differences reported by test_cmp_counters().
 */
static bool __skpair_poll_should_stop(int sk, struct tcp_counters *c,
				      test_cnt condition)
{
	struct tcp_counters now;
	test_cnt changed;

	if (test_get_tcp_counters(sk, &now))
		test_error("test_get_tcp_counters()");

	changed = test_cmp_counters(c, &now);
	/* The snapshot is only needed for the comparison above */
	test_tcp_counters_free(&now);

	if ((changed & condition) != condition)
		return false;
	return true;
}
87 
88 /* How often wake up and check netns counters & paired (*err) */
89 #define POLL_USEC	150
__test_skpair_poll(int sk,bool write,uint64_t timeout,struct tcp_counters * c,test_cnt cond,volatile int * err)90 static int __test_skpair_poll(int sk, bool write, uint64_t timeout,
91 			      struct tcp_counters *c, test_cnt cond,
92 			      volatile int *err)
93 {
94 	uint64_t t;
95 
96 	for (t = 0; t <= timeout * 1000000; t += POLL_USEC) {
97 		struct timeval tv = { .tv_usec = POLL_USEC, };
98 		int ret;
99 
100 		ret = __test_wait_fd(sk, &tv, write);
101 		if (ret != -ETIMEDOUT)
102 			return ret;
103 		if (c && cond && __skpair_poll_should_stop(sk, c, cond))
104 			break;
105 		if (err && *err)
106 			return *err;
107 	}
108 	if (err)
109 		*err = -ETIMEDOUT;
110 	return -ETIMEDOUT;
111 }
112 
__test_connect_socket(int sk,const char * device,void * addr,size_t addr_sz,bool async)113 int __test_connect_socket(int sk, const char *device,
114 			  void *addr, size_t addr_sz, bool async)
115 {
116 	long flags;
117 	int err;
118 
119 	if (device != NULL) {
120 		err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, device,
121 				 strlen(device) + 1);
122 		if (err < 0)
123 			test_error("setsockopt(SO_BINDTODEVICE, %s)", device);
124 	}
125 
126 	flags = fcntl(sk, F_GETFL);
127 	if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0))
128 		test_error("fcntl()");
129 
130 	if (connect(sk, addr, addr_sz) < 0) {
131 		if (errno != EINPROGRESS) {
132 			err = -errno;
133 			goto out;
134 		}
135 		if (async)
136 			return sk;
137 		err = test_wait_fd(sk, TEST_TIMEOUT_SEC, 1);
138 		if (err)
139 			goto out;
140 	}
141 	return sk;
142 
143 out:
144 	close(sk);
145 	return err;
146 }
147 
/*
 * Snapshot the counters of @sk, synchronize with the other thread and then
 * poll the socket until it is ready, (*err) is set, the counters change as
 * described by @cond, or TEST_TIMEOUT_SEC passes.
 *
 * (*err) is reset to 0 on entry. Returns the result of
 * __test_skpair_poll().
 */
int test_skpair_wait_poll(int sk, bool write,
			  test_cnt cond, volatile int *err)
{
	struct tcp_counters before;
	int res;

	*err = 0;
	if (test_get_tcp_counters(sk, &before) != 0)
		test_error("test_get_tcp_counters()");
	synchronize_threads(); /* 1: init skpair & read nscounters */

	res = __test_skpair_poll(sk, write, TEST_TIMEOUT_SEC,
				 &before, cond, err);
	test_tcp_counters_free(&before);

	return res;
}
163 
/*
 * Snapshot the counters of @sk, synchronize with the other thread, start
 * an asynchronous connect to @addr and poll for its completion.
 *
 * (*err) is reset to 0 on entry. If starting the connect fails, its
 * negative result is stored in (*err) and returned. Otherwise the result
 * of __test_skpair_poll() is returned, and @sk is closed here when that
 * result is negative.
 */
int _test_skpair_connect_poll(int sk, const char *device,
			      void *addr, size_t addr_sz,
			      test_cnt condition, volatile int *err)
{
	struct tcp_counters before;
	int res;

	*err = 0;
	if (test_get_tcp_counters(sk, &before) != 0)
		test_error("test_get_tcp_counters()");
	synchronize_threads(); /* 1: init skpair & read nscounters */

	res = __test_connect_socket(sk, device, addr, addr_sz, true);
	if (res < 0) {
		test_tcp_counters_free(&before);
		*err = res;
		return res;
	}

	res = __test_skpair_poll(sk, 1, TEST_TIMEOUT_SEC,
				 &before, condition, err);
	if (res < 0)
		close(sk);
	test_tcp_counters_free(&before);

	return res;
}
186 
__test_set_md5(int sk,void * addr,size_t addr_sz,uint8_t prefix,int vrf,const char * password)187 int __test_set_md5(int sk, void *addr, size_t addr_sz, uint8_t prefix,
188 		   int vrf, const char *password)
189 {
190 	size_t pwd_len = strlen(password);
191 	struct tcp_md5sig md5sig = {};
192 
193 	md5sig.tcpm_keylen = pwd_len;
194 	memcpy(md5sig.tcpm_key, password, pwd_len);
195 	md5sig.tcpm_flags = TCP_MD5SIG_FLAG_PREFIX;
196 	md5sig.tcpm_prefixlen = prefix;
197 	if (vrf >= 0) {
198 		md5sig.tcpm_flags |= TCP_MD5SIG_FLAG_IFINDEX;
199 		md5sig.tcpm_ifindex = (uint8_t)vrf;
200 	}
201 	memcpy(&md5sig.tcpm_addr, addr, addr_sz);
202 
203 	errno = 0;
204 	return setsockopt(sk, IPPROTO_TCP, TCP_MD5SIG_EXT,
205 			&md5sig, sizeof(md5sig));
206 }
207 
208 
test_prepare_key_sockaddr(struct tcp_ao_add * ao,const char * alg,void * addr,size_t addr_sz,bool set_current,bool set_rnext,uint8_t prefix,uint8_t vrf,uint8_t sndid,uint8_t rcvid,uint8_t maclen,uint8_t keyflags,uint8_t keylen,const char * key)209 int test_prepare_key_sockaddr(struct tcp_ao_add *ao, const char *alg,
210 		void *addr, size_t addr_sz, bool set_current, bool set_rnext,
211 		uint8_t prefix, uint8_t vrf, uint8_t sndid, uint8_t rcvid,
212 		uint8_t maclen, uint8_t keyflags,
213 		uint8_t keylen, const char *key)
214 {
215 	memset(ao, 0, sizeof(struct tcp_ao_add));
216 
217 	ao->set_current	= !!set_current;
218 	ao->set_rnext	= !!set_rnext;
219 	ao->prefix	= prefix;
220 	ao->sndid	= sndid;
221 	ao->rcvid	= rcvid;
222 	ao->maclen	= maclen;
223 	ao->keyflags	= keyflags;
224 	ao->keylen	= keylen;
225 	ao->ifindex	= vrf;
226 
227 	memcpy(&ao->addr, addr, addr_sz);
228 
229 	if (strlen(alg) > 64)
230 		return -ENOBUFS;
231 	strncpy(ao->alg_name, alg, 64);
232 
233 	memcpy(ao->key, key,
234 	       (keylen > TCP_AO_MAXKEYLEN) ? TCP_AO_MAXKEYLEN : keylen);
235 	return 0;
236 }
237 
test_get_ao_keys_nr(int sk)238 static int test_get_ao_keys_nr(int sk)
239 {
240 	struct tcp_ao_getsockopt tmp = {};
241 	socklen_t tmp_sz = sizeof(tmp);
242 	int ret;
243 
244 	tmp.nkeys  = 1;
245 	tmp.get_all = 1;
246 
247 	ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz);
248 	if (ret)
249 		return -errno;
250 	return (int)tmp.nkeys;
251 }
252 
test_get_one_ao(int sk,struct tcp_ao_getsockopt * out,void * addr,size_t addr_sz,uint8_t prefix,uint8_t sndid,uint8_t rcvid)253 int test_get_one_ao(int sk, struct tcp_ao_getsockopt *out,
254 		void *addr, size_t addr_sz, uint8_t prefix,
255 		uint8_t sndid, uint8_t rcvid)
256 {
257 	struct tcp_ao_getsockopt tmp = {};
258 	socklen_t tmp_sz = sizeof(tmp);
259 	int ret;
260 
261 	memcpy(&tmp.addr, addr, addr_sz);
262 	tmp.prefix = prefix;
263 	tmp.sndid  = sndid;
264 	tmp.rcvid  = rcvid;
265 	tmp.nkeys  = 1;
266 
267 	ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz);
268 	if (ret)
269 		return ret;
270 	if (tmp.nkeys != 1)
271 		return -E2BIG;
272 	*out = tmp;
273 	return 0;
274 }
275 
test_get_ao_info(int sk,struct tcp_ao_info_opt * out)276 int test_get_ao_info(int sk, struct tcp_ao_info_opt *out)
277 {
278 	socklen_t sz = sizeof(*out);
279 
280 	out->reserved = 0;
281 	out->reserved2 = 0;
282 	if (getsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, out, &sz))
283 		return -errno;
284 	if (sz != sizeof(*out))
285 		return -EMSGSIZE;
286 	return 0;
287 }
288 
test_set_ao_info(int sk,struct tcp_ao_info_opt * in)289 int test_set_ao_info(int sk, struct tcp_ao_info_opt *in)
290 {
291 	socklen_t sz = sizeof(*in);
292 
293 	in->reserved = 0;
294 	in->reserved2 = 0;
295 	if (setsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, in, sz))
296 		return -errno;
297 	return 0;
298 }
299 
test_cmp_getsockopt_setsockopt(const struct tcp_ao_add * a,const struct tcp_ao_getsockopt * b)300 int test_cmp_getsockopt_setsockopt(const struct tcp_ao_add *a,
301 				   const struct tcp_ao_getsockopt *b)
302 {
303 	bool is_kdf_aes_128_cmac = false;
304 	bool is_cmac_aes = false;
305 
306 	if (!strcmp("cmac(aes128)", a->alg_name)) {
307 		is_kdf_aes_128_cmac = (a->keylen != 16);
308 		is_cmac_aes = true;
309 	}
310 
311 #define __cmp_ao(member)						\
312 do {									\
313 	if (b->member != a->member) {					\
314 		test_fail("getsockopt(): " __stringify(member) " %u != %u",	\
315 				b->member, a->member);			\
316 		return -1;						\
317 	}								\
318 } while(0)
319 	__cmp_ao(sndid);
320 	__cmp_ao(rcvid);
321 	__cmp_ao(prefix);
322 	__cmp_ao(keyflags);
323 	__cmp_ao(ifindex);
324 	if (a->maclen) {
325 		__cmp_ao(maclen);
326 	} else if (b->maclen != 12) {
327 		test_fail("getsockopt(): expected default maclen 12, but it's %u",
328 				b->maclen);
329 		return -1;
330 	}
331 	if (!is_kdf_aes_128_cmac) {
332 		__cmp_ao(keylen);
333 	} else if (b->keylen != 16) {
334 		test_fail("getsockopt(): expected keylen 16 for cmac(aes128), but it's %u",
335 				b->keylen);
336 		return -1;
337 	}
338 #undef __cmp_ao
339 	if (!is_kdf_aes_128_cmac && memcmp(b->key, a->key, a->keylen)) {
340 		test_fail("getsockopt(): returned key is different `%s' != `%s'",
341 				b->key, a->key);
342 		return -1;
343 	}
344 	if (memcmp(&b->addr, &a->addr, sizeof(b->addr))) {
345 		test_fail("getsockopt(): returned address is different");
346 		return -1;
347 	}
348 	if (!is_cmac_aes && strcmp(b->alg_name, a->alg_name)) {
349 		test_fail("getsockopt(): returned algorithm %s is different than %s", b->alg_name, a->alg_name);
350 		return -1;
351 	}
352 	if (is_cmac_aes && strcmp(b->alg_name, "cmac(aes)")) {
353 		test_fail("getsockopt(): returned algorithm %s is different than cmac(aes)", b->alg_name);
354 		return -1;
355 	}
356 	/* For a established key rotation test don't add a key with
357 	 * set_current = 1, as it's likely to change by peer's request;
358 	 * rather use setsockopt(TCP_AO_INFO)
359 	 */
360 	if (a->set_current != b->is_current) {
361 		test_fail("getsockopt(): returned key is not Current_key");
362 		return -1;
363 	}
364 	if (a->set_rnext != b->is_rnext) {
365 		test_fail("getsockopt(): returned key is not RNext_key");
366 		return -1;
367 	}
368 
369 	return 0;
370 }
371 
test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt * a,const struct tcp_ao_info_opt * b)372 int test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt *a,
373 				      const struct tcp_ao_info_opt *b)
374 {
375 	/* No check for ::current_key, as it may change by the peer */
376 	if (a->ao_required != b->ao_required) {
377 		test_fail("getsockopt(): returned ao doesn't have ao_required");
378 		return -1;
379 	}
380 	if (a->accept_icmps != b->accept_icmps) {
381 		test_fail("getsockopt(): returned ao doesn't accept ICMPs");
382 		return -1;
383 	}
384 	if (a->set_rnext && a->rnext != b->rnext) {
385 		test_fail("getsockopt(): RNext KeyID has changed");
386 		return -1;
387 	}
388 #define __cmp_cnt(member)						\
389 do {									\
390 	if (b->member != a->member) {					\
391 		test_fail("getsockopt(): " __stringify(member) " %llu != %llu",	\
392 				b->member, a->member);			\
393 		return -1;						\
394 	}								\
395 } while(0)
396 	if (a->set_counters) {
397 		__cmp_cnt(pkt_good);
398 		__cmp_cnt(pkt_bad);
399 		__cmp_cnt(pkt_key_not_found);
400 		__cmp_cnt(pkt_ao_required);
401 		__cmp_cnt(pkt_dropped_icmp);
402 	}
403 #undef __cmp_cnt
404 	return 0;
405 }
406 
test_get_tcp_counters(int sk,struct tcp_counters * out)407 int test_get_tcp_counters(int sk, struct tcp_counters *out)
408 {
409 	struct tcp_ao_getsockopt *key_dump;
410 	socklen_t key_dump_sz = sizeof(*key_dump);
411 	struct tcp_ao_info_opt info = {};
412 	bool c1, c2, c3, c4, c5, c6, c7, c8;
413 	struct netstat *ns;
414 	int err, nr_keys;
415 
416 	memset(out, 0, sizeof(*out));
417 
418 	/* per-netns */
419 	ns = netstat_read();
420 	out->ao.netns_ao_good = netstat_get(ns, "TCPAOGood", &c1);
421 	out->ao.netns_ao_bad = netstat_get(ns, "TCPAOBad", &c2);
422 	out->ao.netns_ao_key_not_found = netstat_get(ns, "TCPAOKeyNotFound", &c3);
423 	out->ao.netns_ao_required = netstat_get(ns, "TCPAORequired", &c4);
424 	out->ao.netns_ao_dropped_icmp = netstat_get(ns, "TCPAODroppedIcmps", &c5);
425 	out->netns_md5_notfound = netstat_get(ns, "TCPMD5NotFound", &c6);
426 	out->netns_md5_unexpected = netstat_get(ns, "TCPMD5Unexpected", &c7);
427 	out->netns_md5_failure = netstat_get(ns, "TCPMD5Failure", &c8);
428 	netstat_free(ns);
429 	if (c1 || c2 || c3 || c4 || c5 || c6 || c7 || c8)
430 		return -EOPNOTSUPP;
431 
432 	err = test_get_ao_info(sk, &info);
433 	if (err == -ENOENT)
434 		return 0;
435 	if (err)
436 		return err;
437 
438 	/* per-socket */
439 	out->ao.ao_info_pkt_good = info.pkt_good;
440 	out->ao.ao_info_pkt_bad = info.pkt_bad;
441 	out->ao.ao_info_pkt_key_not_found = info.pkt_key_not_found;
442 	out->ao.ao_info_pkt_ao_required = info.pkt_ao_required;
443 	out->ao.ao_info_pkt_dropped_icmp = info.pkt_dropped_icmp;
444 
445 	/* per-key */
446 	nr_keys = test_get_ao_keys_nr(sk);
447 	if (nr_keys < 0)
448 		return nr_keys;
449 	if (nr_keys == 0)
450 		test_error("test_get_ao_keys_nr() == 0");
451 	out->ao.nr_keys = (size_t)nr_keys;
452 	key_dump = calloc(nr_keys, key_dump_sz);
453 	if (!key_dump)
454 		return -errno;
455 
456 	key_dump[0].nkeys = nr_keys;
457 	key_dump[0].get_all = 1;
458 	err = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS,
459 			 key_dump, &key_dump_sz);
460 	if (err) {
461 		free(key_dump);
462 		return -errno;
463 	}
464 
465 	out->ao.key_cnts = calloc(nr_keys, sizeof(out->ao.key_cnts[0]));
466 	if (!out->ao.key_cnts) {
467 		free(key_dump);
468 		return -errno;
469 	}
470 
471 	while (nr_keys--) {
472 		out->ao.key_cnts[nr_keys].sndid = key_dump[nr_keys].sndid;
473 		out->ao.key_cnts[nr_keys].rcvid = key_dump[nr_keys].rcvid;
474 		out->ao.key_cnts[nr_keys].pkt_good = key_dump[nr_keys].pkt_good;
475 		out->ao.key_cnts[nr_keys].pkt_bad = key_dump[nr_keys].pkt_bad;
476 	}
477 	free(key_dump);
478 
479 	return 0;
480 }
481 
test_cmp_counters(struct tcp_counters * before,struct tcp_counters * after)482 test_cnt test_cmp_counters(struct tcp_counters *before,
483 			   struct tcp_counters *after)
484 {
485 #define __cmp(cnt, e_cnt)						\
486 do {									\
487 	if (before->cnt > after->cnt)					\
488 		test_error("counter " __stringify(cnt) " decreased");	\
489 	if (before->cnt != after->cnt)					\
490 		ret |= e_cnt;						\
491 } while (0)
492 
493 	test_cnt ret = 0;
494 	size_t i;
495 
496 	if (before->ao.nr_keys != after->ao.nr_keys)
497 		test_error("the number of keys has changed");
498 
499 	_for_each_counter(__cmp);
500 
501 	i = before->ao.nr_keys;
502 	while (i--) {
503 		__cmp(ao.key_cnts[i].pkt_good, TEST_CNT_KEY_GOOD);
504 		__cmp(ao.key_cnts[i].pkt_bad, TEST_CNT_KEY_BAD);
505 	}
506 #undef __cmp
507 	return ret;
508 }
509 
/*
 * Verify that the counters enumerated by _for_each_counter() changed
 * between @before and @after exactly as described by @expected: a counter
 * must have changed if its TEST_CNT_* bit is set in @expected and must be
 * unchanged otherwise. No counter may decrease.
 *
 * @tst_name: prefix for the failure messages; NULL is printed as "".
 *
 * errno is reset to 0 on entry. Violations are reported with test_fail().
 * Returns 0 if all counters behave as expected, -1 on the first violation.
 */
int test_assert_counters_sk(const char *tst_name,
			    struct tcp_counters *before,
			    struct tcp_counters *after,
			    test_cnt expected)
{
#define __check_cnt(cnt, e_cnt)						\
do {									\
	const bool should_grow = !!(expected & e_cnt);			\
									\
	if (before->cnt > after->cnt) {					\
		test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64, \
			  tst_name ?: "", before->cnt, after->cnt);	\
		return -1;						\
	}								\
	if ((before->cnt != after->cnt) != should_grow) {		\
		test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64, \
			  tst_name ?: "", should_grow ? "" : "not ",	\
			  before->cnt, after->cnt);			\
		return -1;						\
	}								\
} while (0)

	errno = 0;
	_for_each_counter(__check_cnt);
#undef __check_cnt

	return 0;
}
535 
/*
 * Verify that the per-key pkt_good/pkt_bad counters changed between
 * @before and @after exactly as described by @expected
 * (TEST_CNT_KEY_GOOD / TEST_CNT_KEY_BAD). No counter may decrease.
 *
 * @tst_name: prefix for the failure messages; NULL is printed as "".
 * @sndid:    when >= 0, only keys with this sndid are checked
 * @rcvid:    when >= 0, only keys with this rcvid are checked
 *
 * Keys are matched by their position in the two arrays, so both snapshots
 * must hold the same number of keys.
 *
 * Violations are reported with test_fail().
 * Returns 0 if all selected keys behave as expected, -1 otherwise.
 */
int test_assert_counters_key(const char *tst_name,
			     struct tcp_ao_counters *before,
			     struct tcp_ao_counters *after,
			     test_cnt expected, int sndid, int rcvid)
{
	size_t i;
#define __cmp_ao(i, cnt, e_cnt)					\
do {									\
	if (before->key_cnts[i].cnt > after->key_cnts[i].cnt) {		\
		test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64 " for key %u:%u", \
			  tst_name ?: "", before->key_cnts[i].cnt,	\
			  after->key_cnts[i].cnt,			\
			  before->key_cnts[i].sndid,			\
			  before->key_cnts[i].rcvid);			\
		return -1;						\
	}								\
	if ((before->key_cnts[i].cnt != after->key_cnts[i].cnt) != !!(expected & e_cnt)) {		\
		test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64 " for key %u:%u", \
			  tst_name ?: "", (expected & e_cnt) ? "" : "not ",\
			  before->key_cnts[i].cnt,			\
			  after->key_cnts[i].cnt,			\
			  before->key_cnts[i].sndid,			\
			  before->key_cnts[i].rcvid);			\
		return -1;						\
	}								\
} while (0)

	if (before->nr_keys != after->nr_keys) {
		/* tst_name may be NULL, as in the messages above */
		test_fail("%s: Keys changed on the socket %zu != %zu",
			  tst_name ?: "", before->nr_keys, after->nr_keys);
		return -1;
	}

	/* per-key */
	i = before->nr_keys;
	while (i--) {
		if (sndid >= 0 && before->key_cnts[i].sndid != sndid)
			continue;
		if (rcvid >= 0 && before->key_cnts[i].rcvid != rcvid)
			continue;
		__cmp_ao(i, pkt_good, TEST_CNT_KEY_GOOD);
		__cmp_ao(i, pkt_bad, TEST_CNT_KEY_BAD);
	}
	return 0;
#undef __cmp_ao
}
582 
test_tcp_counters_free(struct tcp_counters * cnts)583 void test_tcp_counters_free(struct tcp_counters *cnts)
584 {
585 	free(cnts->ao.key_cnts);
586 }
587 
588 #define TEST_BUF_SIZE 4096
_test_server_run(int sk,ssize_t quota,struct tcp_counters * c,test_cnt cond,volatile int * err,time_t timeout_sec)589 static ssize_t _test_server_run(int sk, ssize_t quota, struct tcp_counters *c,
590 				test_cnt cond, volatile int *err,
591 				time_t timeout_sec)
592 {
593 	ssize_t total = 0;
594 
595 	do {
596 		char buf[TEST_BUF_SIZE];
597 		ssize_t bytes, sent;
598 		int ret;
599 
600 		ret = __test_skpair_poll(sk, 0, timeout_sec, c, cond, err);
601 		if (ret)
602 			return ret;
603 
604 		bytes = recv(sk, buf, sizeof(buf), 0);
605 
606 		if (bytes < 0)
607 			test_error("recv(): %zd", bytes);
608 		if (bytes == 0)
609 			break;
610 
611 		ret = __test_skpair_poll(sk, 1, timeout_sec, c, cond, err);
612 		if (ret)
613 			return ret;
614 
615 		sent = send(sk, buf, bytes, 0);
616 		if (sent == 0)
617 			break;
618 		if (sent != bytes)
619 			test_error("send()");
620 		total += bytes;
621 	} while (!quota || total < quota);
622 
623 	return total;
624 }
625 
test_server_run(int sk,ssize_t quota,time_t timeout_sec)626 ssize_t test_server_run(int sk, ssize_t quota, time_t timeout_sec)
627 {
628 	return _test_server_run(sk, quota, NULL, 0, NULL,
629 				timeout_sec ?: TEST_TIMEOUT_SEC);
630 }
631 
/*
 * Server side of a socket pair: snapshot the counters of @sk, synchronize
 * with the other thread and run the echo loop with the counters condition
 * @cond and the shared status @err.
 *
 * (*err) is reset to 0 on entry. Returns the result of
 * _test_server_run(), converted to int.
 */
int test_skpair_server(int sk, ssize_t quota, test_cnt cond, volatile int *err)
{
	struct tcp_counters before;
	ssize_t res;

	*err = 0;
	if (test_get_tcp_counters(sk, &before) != 0)
		test_error("test_get_tcp_counters()");
	synchronize_threads(); /* 1: init skpair & read nscounters */

	res = _test_server_run(sk, quota, &before, cond, err,
			       TEST_TIMEOUT_SEC);
	test_tcp_counters_free(&before);

	return res;
}
646 
/*
 * Client side of the echo test: send @buf_sz bytes of random data to @sk
 * in messages of up to @msg_len bytes and check that each message is
 * echoed back unchanged. TCP_NODELAY is enabled on @sk.
 *
 * @c, @cond, @err: passed through to __test_skpair_poll(), which is
 *                  called before every send() and recv()
 *
 * Both buffers live on the heap, so their size is not limited by the
 * stack and an allocation failure is actually detectable.
 *
 * Returns the number of bytes sent and verified (short if send() returned
 * 0 or recv() returned <= 0), -ENOMEM if the buffers can't be allocated,
 * the non-zero result of __test_skpair_poll(), or -1 if the echoed data
 * differs. A partial send() or an over-long echo is reported through
 * test_error(). errno is left as it was when the result was determined.
 */
static ssize_t test_client_loop(int sk, size_t buf_sz, const size_t msg_len,
				struct tcp_counters *c, test_cnt cond,
				volatile int *err)
{
	char *buf, *msg;
	int nodelay = 1;
	int saved_errno;
	ssize_t ret;
	size_t i;

	/* malloc(0) may return NULL: always ask for at least one byte */
	buf = malloc(buf_sz ?: 1);
	msg = malloc(msg_len ?: 1);
	if (!buf || !msg) {
		ret = -ENOMEM;
		goto out;
	}
	randomize_buffer(buf, buf_sz);

	if (setsockopt(sk, IPPROTO_TCP, TCP_NODELAY, &nodelay, sizeof(nodelay)))
		test_error("setsockopt(TCP_NODELAY)");

	for (i = 0; i < buf_sz; i += min(msg_len, buf_sz - i)) {
		size_t bytes = min(msg_len, buf_sz - i);
		size_t got_total = 0;
		ssize_t sent;

		ret = __test_skpair_poll(sk, 1, TEST_TIMEOUT_SEC, c, cond, err);
		if (ret)
			goto out;

		sent = send(sk, buf + i, bytes, 0);
		if (sent == 0)
			break;
		if (sent < 0 || (size_t)sent != bytes)
			test_error("send()");

		/* The echo may arrive in pieces: collect the whole message */
		do {
			ssize_t got;

			ret = __test_skpair_poll(sk, 0, TEST_TIMEOUT_SEC,
						 c, cond, err);
			if (ret)
				goto out;

			got = recv(sk, msg + got_total, msg_len - got_total, 0);
			if (got <= 0) {
				ret = i;
				goto out;
			}
			got_total += got;
		} while (got_total < bytes);
		if (got_total > bytes)
			test_error("recv(): %zu > %zu", got_total, bytes);
		if (memcmp(buf + i, msg, got_total) != 0) {
			test_fail("received message differs");
			ret = -1;
			goto out;
		}
	}
	ret = i;

out:
	/* free() is not guaranteed to keep errno on every libc */
	saved_errno = errno;
	free(msg);
	free(buf);
	errno = saved_errno;
	return ret;
}
701 
/*
 * Send @nr messages of @msg_len bytes each over @sk and verify that they
 * are echoed back, without any counters condition or shared error status.
 *
 * Returns 0 if all msg_len * nr bytes were verified, the negative result
 * of test_client_loop() on error, or -1 if fewer bytes were verified.
 */
int test_client_verify(int sk, const size_t msg_len, const size_t nr)
{
	const size_t total = msg_len * nr;
	ssize_t done;

	done = test_client_loop(sk, total, msg_len, NULL, 0, NULL);
	if (done < 0)
		return (int)done;
	if (done != total)
		return -1;
	return 0;
}
712 
/*
 * Client side of a socket pair: snapshot the counters of @sk, synchronize
 * with the other thread, then send @nr messages of @msg_len bytes each and
 * verify the echo, with the counters condition @cond and the shared status
 * @err.
 *
 * (*err) is reset to 0 on entry. Returns 0 if all msg_len * nr bytes were
 * verified, the negative result of test_client_loop() on error, or -1 if
 * fewer bytes were verified.
 */
int test_skpair_client(int sk, const size_t msg_len, const size_t nr,
		       test_cnt cond, volatile int *err)
{
	const size_t total = msg_len * nr;
	struct tcp_counters before;
	ssize_t done;

	*err = 0;
	if (test_get_tcp_counters(sk, &before) != 0)
		test_error("test_get_tcp_counters()");
	synchronize_threads(); /* 1: init skpair & read nscounters */

	done = test_client_loop(sk, total, msg_len, &before, cond, err);
	test_tcp_counters_free(&before);

	if (done < 0)
		return (int)done;
	if (done != total)
		return -1;
	return 0;
}
731