1 // SPDX-License-Identifier: GPL-2.0
2 #include <alloca.h>
3 #include <fcntl.h>
4 #include <inttypes.h>
5 #include <string.h>
6 #include "../../../../../include/linux/kernel.h"
7 #include "../../../../../include/linux/stringify.h"
8 #include "aolib.h"
9
/* TCP port number exported to the tests under the name test_server_port. */
const unsigned int test_server_port = 7010;
__test_listen_socket(int backlog,void * addr,size_t addr_sz)11 int __test_listen_socket(int backlog, void *addr, size_t addr_sz)
12 {
13 int err, sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
14 long flags;
15
16 if (sk < 0)
17 test_error("socket()");
18
19 err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, veth_name,
20 strlen(veth_name) + 1);
21 if (err < 0)
22 test_error("setsockopt(SO_BINDTODEVICE)");
23
24 if (bind(sk, (struct sockaddr *)addr, addr_sz) < 0)
25 test_error("bind()");
26
27 flags = fcntl(sk, F_GETFL);
28 if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0))
29 test_error("fcntl()");
30
31 if (listen(sk, backlog))
32 test_error("listen()");
33
34 return sk;
35 }
36
/*
 * Wait with select() until @sk becomes writable (@write == true) or
 * readable (@write == false), then fetch the pending socket error.
 *
 * @tv is handed to select() as is; NULL means no timeout.
 *
 * Returns 0 when the descriptor is ready and SO_ERROR is clear,
 * -ETIMEDOUT (with errno set to ETIMEDOUT) when select() reported no ready
 * descriptors, -errno when select() or getsockopt() failed, or the negated
 * SO_ERROR value otherwise.
 */
static int __test_wait_fd(int sk, struct timeval *tv, bool write)
{
	fd_set ready, except;
	socklen_t so_err_len;
	int so_err = 0;
	int nready;

	FD_ZERO(&ready);
	FD_ZERO(&except);
	FD_SET(sk, &ready);
	FD_SET(sk, &except);

	errno = 0;
	/* The same set serves as either the read set or the write set */
	nready = select(sk + 1, write ? NULL : &ready, write ? &ready : NULL,
			&except, tv);
	if (nready < 0)
		return -errno;
	if (nready == 0) {
		errno = ETIMEDOUT;
		return -ETIMEDOUT;
	}

	so_err_len = sizeof(so_err);
	if (getsockopt(sk, SOL_SOCKET, SO_ERROR, &so_err, &so_err_len) != 0)
		return -errno;

	return so_err ? -so_err : 0;
}
66
/*
 * Wait for @sk to become writable (@write) or readable (!@write) for up to
 * @sec seconds. A zero @sec passes a NULL timeout down, i.e. select() is
 * called without a time limit.
 *
 * Returns whatever __test_wait_fd() returns.
 */
int test_wait_fd(int sk, time_t sec, bool write)
{
	struct timeval tv = { .tv_sec = sec, };
	struct timeval *timeout = NULL;

	if (sec)
		timeout = &tv;

	return __test_wait_fd(sk, timeout, write);
}
73
/*
 * Take a fresh counters snapshot for @sk and compare it with the earlier
 * snapshot @c.
 *
 * Returns true only if every bit of @condition is present in the mask
 * produced by test_cmp_counters(). Failure to read the counters is reported
 * through test_error().
 */
static bool __skpair_poll_should_stop(int sk, struct tcp_counters *c,
				      test_cnt condition)
{
	struct tcp_counters now;
	test_cnt changed;

	if (test_get_tcp_counters(sk, &now) != 0)
		test_error("test_get_tcp_counters()");

	changed = test_cmp_counters(c, &now);
	test_tcp_counters_free(&now);

	if ((changed & condition) != condition)
		return false;
	return true;
}
87
88 /* How often wake up and check netns counters & paired (*err) */
89 #define POLL_USEC 150
__test_skpair_poll(int sk,bool write,uint64_t timeout,struct tcp_counters * c,test_cnt cond,volatile int * err)90 static int __test_skpair_poll(int sk, bool write, uint64_t timeout,
91 struct tcp_counters *c, test_cnt cond,
92 volatile int *err)
93 {
94 uint64_t t;
95
96 for (t = 0; t <= timeout * 1000000; t += POLL_USEC) {
97 struct timeval tv = { .tv_usec = POLL_USEC, };
98 int ret;
99
100 ret = __test_wait_fd(sk, &tv, write);
101 if (ret != -ETIMEDOUT)
102 return ret;
103 if (c && cond && __skpair_poll_should_stop(sk, c, cond))
104 break;
105 if (err && *err)
106 return *err;
107 }
108 if (err)
109 *err = -ETIMEDOUT;
110 return -ETIMEDOUT;
111 }
112
__test_connect_socket(int sk,const char * device,void * addr,size_t addr_sz,bool async)113 int __test_connect_socket(int sk, const char *device,
114 void *addr, size_t addr_sz, bool async)
115 {
116 long flags;
117 int err;
118
119 if (device != NULL) {
120 err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, device,
121 strlen(device) + 1);
122 if (err < 0)
123 test_error("setsockopt(SO_BINDTODEVICE, %s)", device);
124 }
125
126 flags = fcntl(sk, F_GETFL);
127 if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0))
128 test_error("fcntl()");
129
130 if (connect(sk, addr, addr_sz) < 0) {
131 if (errno != EINPROGRESS) {
132 err = -errno;
133 goto out;
134 }
135 if (async)
136 return sk;
137 err = test_wait_fd(sk, TEST_TIMEOUT_SEC, 1);
138 if (err)
139 goto out;
140 }
141 return sk;
142
143 out:
144 close(sk);
145 return err;
146 }
147
/*
 * Reset *@err, snapshot the counters of @sk, synchronize with the other
 * threads and then poll @sk (for writing if @write, otherwise for reading)
 * for up to TEST_TIMEOUT_SEC seconds, watching @cond and *@err.
 *
 * Returns the result of __test_skpair_poll().
 */
int test_skpair_wait_poll(int sk, bool write,
			  test_cnt cond, volatile int *err)
{
	struct tcp_counters baseline;
	int rc;

	*err = 0;
	if (test_get_tcp_counters(sk, &baseline) != 0)
		test_error("test_get_tcp_counters()");
	synchronize_threads(); /* 1: init skpair & read nscounters */

	rc = __test_skpair_poll(sk, write, TEST_TIMEOUT_SEC,
				&baseline, cond, err);
	test_tcp_counters_free(&baseline);

	return rc;
}
163
/*
 * Client side of a socket pair: snapshot the counters of @sk, synchronize
 * with the other threads, start an asynchronous connect and poll the socket
 * for writability while watching @condition and *@err.
 *
 * Returns the result of __test_skpair_poll(), or the negative error of
 * __test_connect_socket() (also stored in *@err). In both failure cases
 * the socket ends up closed.
 */
int _test_skpair_connect_poll(int sk, const char *device,
			      void *addr, size_t addr_sz,
			      test_cnt condition, volatile int *err)
{
	struct tcp_counters baseline;
	int rc;

	*err = 0;
	if (test_get_tcp_counters(sk, &baseline) != 0)
		test_error("test_get_tcp_counters()");
	synchronize_threads(); /* 1: init skpair & read nscounters */

	rc = __test_connect_socket(sk, device, addr, addr_sz, true);
	if (rc < 0) {
		/* __test_connect_socket() has closed the socket already */
		*err = rc;
		goto out;
	}

	rc = __test_skpair_poll(sk, 1, TEST_TIMEOUT_SEC,
				&baseline, condition, err);
	if (rc < 0)
		close(sk);
out:
	test_tcp_counters_free(&baseline);
	return rc;
}
186
/*
 * Install a TCP-MD5 key on @sk with setsockopt(TCP_MD5SIG_EXT).
 *
 * @addr/@addr_sz: peer address, copied into the request as is.
 * @prefix:        address prefix length (TCP_MD5SIG_FLAG_PREFIX is always
 *                 set).
 * @vrf:           interface index to scope the key to; a negative value
 *                 leaves TCP_MD5SIG_FLAG_IFINDEX unset.
 * @password:      NUL-terminated key; the terminator is not part of the key.
 *
 * Returns the setsockopt() result: 0 on success, -1 with errno set on
 * failure. A password or an address that does not fit into
 * struct tcp_md5sig is rejected up front with errno = EINVAL instead of
 * overflowing the on-stack request.
 */
int __test_set_md5(int sk, void *addr, size_t addr_sz, uint8_t prefix,
		   int vrf, const char *password)
{
	size_t pwd_len = strlen(password);
	struct tcp_md5sig md5sig = {};

	if (pwd_len > sizeof(md5sig.tcpm_key) ||
	    addr_sz > sizeof(md5sig.tcpm_addr)) {
		errno = EINVAL;
		return -1;
	}

	md5sig.tcpm_keylen = pwd_len;
	memcpy(md5sig.tcpm_key, password, pwd_len);
	md5sig.tcpm_flags = TCP_MD5SIG_FLAG_PREFIX;
	md5sig.tcpm_prefixlen = prefix;
	if (vrf >= 0) {
		md5sig.tcpm_flags |= TCP_MD5SIG_FLAG_IFINDEX;
		/* No narrowing cast: indexes above 255 must not be truncated */
		md5sig.tcpm_ifindex = vrf;
	}
	memcpy(&md5sig.tcpm_addr, addr, addr_sz);

	errno = 0;
	return setsockopt(sk, IPPROTO_TCP, TCP_MD5SIG_EXT,
			  &md5sig, sizeof(md5sig));
}
207
208
test_prepare_key_sockaddr(struct tcp_ao_add * ao,const char * alg,void * addr,size_t addr_sz,bool set_current,bool set_rnext,uint8_t prefix,uint8_t vrf,uint8_t sndid,uint8_t rcvid,uint8_t maclen,uint8_t keyflags,uint8_t keylen,const char * key)209 int test_prepare_key_sockaddr(struct tcp_ao_add *ao, const char *alg,
210 void *addr, size_t addr_sz, bool set_current, bool set_rnext,
211 uint8_t prefix, uint8_t vrf, uint8_t sndid, uint8_t rcvid,
212 uint8_t maclen, uint8_t keyflags,
213 uint8_t keylen, const char *key)
214 {
215 memset(ao, 0, sizeof(struct tcp_ao_add));
216
217 ao->set_current = !!set_current;
218 ao->set_rnext = !!set_rnext;
219 ao->prefix = prefix;
220 ao->sndid = sndid;
221 ao->rcvid = rcvid;
222 ao->maclen = maclen;
223 ao->keyflags = keyflags;
224 ao->keylen = keylen;
225 ao->ifindex = vrf;
226
227 memcpy(&ao->addr, addr, addr_sz);
228
229 if (strlen(alg) > 64)
230 return -ENOBUFS;
231 strncpy(ao->alg_name, alg, 64);
232
233 memcpy(ao->key, key,
234 (keylen > TCP_AO_MAXKEYLEN) ? TCP_AO_MAXKEYLEN : keylen);
235 return 0;
236 }
237
test_get_ao_keys_nr(int sk)238 static int test_get_ao_keys_nr(int sk)
239 {
240 struct tcp_ao_getsockopt tmp = {};
241 socklen_t tmp_sz = sizeof(tmp);
242 int ret;
243
244 tmp.nkeys = 1;
245 tmp.get_all = 1;
246
247 ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz);
248 if (ret)
249 return -errno;
250 return (int)tmp.nkeys;
251 }
252
test_get_one_ao(int sk,struct tcp_ao_getsockopt * out,void * addr,size_t addr_sz,uint8_t prefix,uint8_t sndid,uint8_t rcvid)253 int test_get_one_ao(int sk, struct tcp_ao_getsockopt *out,
254 void *addr, size_t addr_sz, uint8_t prefix,
255 uint8_t sndid, uint8_t rcvid)
256 {
257 struct tcp_ao_getsockopt tmp = {};
258 socklen_t tmp_sz = sizeof(tmp);
259 int ret;
260
261 memcpy(&tmp.addr, addr, addr_sz);
262 tmp.prefix = prefix;
263 tmp.sndid = sndid;
264 tmp.rcvid = rcvid;
265 tmp.nkeys = 1;
266
267 ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz);
268 if (ret)
269 return ret;
270 if (tmp.nkeys != 1)
271 return -E2BIG;
272 *out = tmp;
273 return 0;
274 }
275
test_get_ao_info(int sk,struct tcp_ao_info_opt * out)276 int test_get_ao_info(int sk, struct tcp_ao_info_opt *out)
277 {
278 socklen_t sz = sizeof(*out);
279
280 out->reserved = 0;
281 out->reserved2 = 0;
282 if (getsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, out, &sz))
283 return -errno;
284 if (sz != sizeof(*out))
285 return -EMSGSIZE;
286 return 0;
287 }
288
test_set_ao_info(int sk,struct tcp_ao_info_opt * in)289 int test_set_ao_info(int sk, struct tcp_ao_info_opt *in)
290 {
291 socklen_t sz = sizeof(*in);
292
293 in->reserved = 0;
294 in->reserved2 = 0;
295 if (setsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, in, sz))
296 return -errno;
297 return 0;
298 }
299
test_cmp_getsockopt_setsockopt(const struct tcp_ao_add * a,const struct tcp_ao_getsockopt * b)300 int test_cmp_getsockopt_setsockopt(const struct tcp_ao_add *a,
301 const struct tcp_ao_getsockopt *b)
302 {
303 bool is_kdf_aes_128_cmac = false;
304 bool is_cmac_aes = false;
305
306 if (!strcmp("cmac(aes128)", a->alg_name)) {
307 is_kdf_aes_128_cmac = (a->keylen != 16);
308 is_cmac_aes = true;
309 }
310
311 #define __cmp_ao(member) \
312 do { \
313 if (b->member != a->member) { \
314 test_fail("getsockopt(): " __stringify(member) " %u != %u", \
315 b->member, a->member); \
316 return -1; \
317 } \
318 } while(0)
319 __cmp_ao(sndid);
320 __cmp_ao(rcvid);
321 __cmp_ao(prefix);
322 __cmp_ao(keyflags);
323 __cmp_ao(ifindex);
324 if (a->maclen) {
325 __cmp_ao(maclen);
326 } else if (b->maclen != 12) {
327 test_fail("getsockopt(): expected default maclen 12, but it's %u",
328 b->maclen);
329 return -1;
330 }
331 if (!is_kdf_aes_128_cmac) {
332 __cmp_ao(keylen);
333 } else if (b->keylen != 16) {
334 test_fail("getsockopt(): expected keylen 16 for cmac(aes128), but it's %u",
335 b->keylen);
336 return -1;
337 }
338 #undef __cmp_ao
339 if (!is_kdf_aes_128_cmac && memcmp(b->key, a->key, a->keylen)) {
340 test_fail("getsockopt(): returned key is different `%s' != `%s'",
341 b->key, a->key);
342 return -1;
343 }
344 if (memcmp(&b->addr, &a->addr, sizeof(b->addr))) {
345 test_fail("getsockopt(): returned address is different");
346 return -1;
347 }
348 if (!is_cmac_aes && strcmp(b->alg_name, a->alg_name)) {
349 test_fail("getsockopt(): returned algorithm %s is different than %s", b->alg_name, a->alg_name);
350 return -1;
351 }
352 if (is_cmac_aes && strcmp(b->alg_name, "cmac(aes)")) {
353 test_fail("getsockopt(): returned algorithm %s is different than cmac(aes)", b->alg_name);
354 return -1;
355 }
356 /* For a established key rotation test don't add a key with
357 * set_current = 1, as it's likely to change by peer's request;
358 * rather use setsockopt(TCP_AO_INFO)
359 */
360 if (a->set_current != b->is_current) {
361 test_fail("getsockopt(): returned key is not Current_key");
362 return -1;
363 }
364 if (a->set_rnext != b->is_rnext) {
365 test_fail("getsockopt(): returned key is not RNext_key");
366 return -1;
367 }
368
369 return 0;
370 }
371
test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt * a,const struct tcp_ao_info_opt * b)372 int test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt *a,
373 const struct tcp_ao_info_opt *b)
374 {
375 /* No check for ::current_key, as it may change by the peer */
376 if (a->ao_required != b->ao_required) {
377 test_fail("getsockopt(): returned ao doesn't have ao_required");
378 return -1;
379 }
380 if (a->accept_icmps != b->accept_icmps) {
381 test_fail("getsockopt(): returned ao doesn't accept ICMPs");
382 return -1;
383 }
384 if (a->set_rnext && a->rnext != b->rnext) {
385 test_fail("getsockopt(): RNext KeyID has changed");
386 return -1;
387 }
388 #define __cmp_cnt(member) \
389 do { \
390 if (b->member != a->member) { \
391 test_fail("getsockopt(): " __stringify(member) " %llu != %llu", \
392 b->member, a->member); \
393 return -1; \
394 } \
395 } while(0)
396 if (a->set_counters) {
397 __cmp_cnt(pkt_good);
398 __cmp_cnt(pkt_bad);
399 __cmp_cnt(pkt_key_not_found);
400 __cmp_cnt(pkt_ao_required);
401 __cmp_cnt(pkt_dropped_icmp);
402 }
403 #undef __cmp_cnt
404 return 0;
405 }
406
test_get_tcp_counters(int sk,struct tcp_counters * out)407 int test_get_tcp_counters(int sk, struct tcp_counters *out)
408 {
409 struct tcp_ao_getsockopt *key_dump;
410 socklen_t key_dump_sz = sizeof(*key_dump);
411 struct tcp_ao_info_opt info = {};
412 bool c1, c2, c3, c4, c5, c6, c7, c8;
413 struct netstat *ns;
414 int err, nr_keys;
415
416 memset(out, 0, sizeof(*out));
417
418 /* per-netns */
419 ns = netstat_read();
420 out->ao.netns_ao_good = netstat_get(ns, "TCPAOGood", &c1);
421 out->ao.netns_ao_bad = netstat_get(ns, "TCPAOBad", &c2);
422 out->ao.netns_ao_key_not_found = netstat_get(ns, "TCPAOKeyNotFound", &c3);
423 out->ao.netns_ao_required = netstat_get(ns, "TCPAORequired", &c4);
424 out->ao.netns_ao_dropped_icmp = netstat_get(ns, "TCPAODroppedIcmps", &c5);
425 out->netns_md5_notfound = netstat_get(ns, "TCPMD5NotFound", &c6);
426 out->netns_md5_unexpected = netstat_get(ns, "TCPMD5Unexpected", &c7);
427 out->netns_md5_failure = netstat_get(ns, "TCPMD5Failure", &c8);
428 netstat_free(ns);
429 if (c1 || c2 || c3 || c4 || c5 || c6 || c7 || c8)
430 return -EOPNOTSUPP;
431
432 err = test_get_ao_info(sk, &info);
433 if (err == -ENOENT)
434 return 0;
435 if (err)
436 return err;
437
438 /* per-socket */
439 out->ao.ao_info_pkt_good = info.pkt_good;
440 out->ao.ao_info_pkt_bad = info.pkt_bad;
441 out->ao.ao_info_pkt_key_not_found = info.pkt_key_not_found;
442 out->ao.ao_info_pkt_ao_required = info.pkt_ao_required;
443 out->ao.ao_info_pkt_dropped_icmp = info.pkt_dropped_icmp;
444
445 /* per-key */
446 nr_keys = test_get_ao_keys_nr(sk);
447 if (nr_keys < 0)
448 return nr_keys;
449 if (nr_keys == 0)
450 test_error("test_get_ao_keys_nr() == 0");
451 out->ao.nr_keys = (size_t)nr_keys;
452 key_dump = calloc(nr_keys, key_dump_sz);
453 if (!key_dump)
454 return -errno;
455
456 key_dump[0].nkeys = nr_keys;
457 key_dump[0].get_all = 1;
458 err = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS,
459 key_dump, &key_dump_sz);
460 if (err) {
461 free(key_dump);
462 return -errno;
463 }
464
465 out->ao.key_cnts = calloc(nr_keys, sizeof(out->ao.key_cnts[0]));
466 if (!out->ao.key_cnts) {
467 free(key_dump);
468 return -errno;
469 }
470
471 while (nr_keys--) {
472 out->ao.key_cnts[nr_keys].sndid = key_dump[nr_keys].sndid;
473 out->ao.key_cnts[nr_keys].rcvid = key_dump[nr_keys].rcvid;
474 out->ao.key_cnts[nr_keys].pkt_good = key_dump[nr_keys].pkt_good;
475 out->ao.key_cnts[nr_keys].pkt_bad = key_dump[nr_keys].pkt_bad;
476 }
477 free(key_dump);
478
479 return 0;
480 }
481
test_cmp_counters(struct tcp_counters * before,struct tcp_counters * after)482 test_cnt test_cmp_counters(struct tcp_counters *before,
483 struct tcp_counters *after)
484 {
485 #define __cmp(cnt, e_cnt) \
486 do { \
487 if (before->cnt > after->cnt) \
488 test_error("counter " __stringify(cnt) " decreased"); \
489 if (before->cnt != after->cnt) \
490 ret |= e_cnt; \
491 } while (0)
492
493 test_cnt ret = 0;
494 size_t i;
495
496 if (before->ao.nr_keys != after->ao.nr_keys)
497 test_error("the number of keys has changed");
498
499 _for_each_counter(__cmp);
500
501 i = before->ao.nr_keys;
502 while (i--) {
503 __cmp(ao.key_cnts[i].pkt_good, TEST_CNT_KEY_GOOD);
504 __cmp(ao.key_cnts[i].pkt_bad, TEST_CNT_KEY_BAD);
505 }
506 #undef __cmp
507 return ret;
508 }
509
/*
 * Check every counter visited by _for_each_counter(): it must not have
 * decreased between @before and @after, and it must have changed if and
 * only if its TEST_CNT_* bit is set in @expected.
 *
 * @tst_name prefixes the failure messages; NULL is printed as an empty
 * string. errno is reset to 0 before the checks.
 *
 * Returns 0 if all counters behave as expected; the first violation is
 * reported with test_fail() and -1 is returned.
 */
int test_assert_counters_sk(const char *tst_name,
			    struct tcp_counters *before,
			    struct tcp_counters *after,
			    test_cnt expected)
{
	const char *name = tst_name ?: "";

#define __verify_cnt(cnt, e_cnt)					\
do {									\
	const bool should_grow = !!(expected & e_cnt);			\
									\
	if (before->cnt > after->cnt) {					\
		test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64, \
			  name, before->cnt, after->cnt);		\
		return -1;						\
	}								\
	if ((before->cnt != after->cnt) != should_grow) {		\
		test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64, \
			  name, should_grow ? "" : "not ",		\
			  before->cnt, after->cnt);			\
		return -1;						\
	}								\
} while (0)

	errno = 0;
	_for_each_counter(__verify_cnt);
#undef __verify_cnt

	return 0;
}
535
/*
 * Check the per-key pkt_good/pkt_bad counters between two snapshots.
 *
 * Only keys matching @sndid and @rcvid are checked; a negative id matches
 * any key. For each checked key a counter must not have decreased, and it
 * must have changed if and only if TEST_CNT_KEY_GOOD / TEST_CNT_KEY_BAD is
 * set in @expected.
 *
 * @tst_name prefixes the failure messages; NULL is printed as an empty
 * string in every message, including the "Keys changed" one.
 *
 * Returns 0 if all checked keys behave as expected; the first violation,
 * or a differing number of keys, is reported with test_fail() and -1 is
 * returned.
 */
int test_assert_counters_key(const char *tst_name,
			     struct tcp_ao_counters *before,
			     struct tcp_ao_counters *after,
			     test_cnt expected, int sndid, int rcvid)
{
	size_t i;
#define __cmp_ao(i, cnt, e_cnt)						\
do {									\
	if (before->key_cnts[i].cnt > after->key_cnts[i].cnt) {		\
		test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64 " for key %u:%u", \
			  tst_name ?: "", before->key_cnts[i].cnt,	\
			  after->key_cnts[i].cnt,			\
			  before->key_cnts[i].sndid,			\
			  before->key_cnts[i].rcvid);			\
		return -1;						\
	}								\
	if ((before->key_cnts[i].cnt != after->key_cnts[i].cnt) != !!(expected & e_cnt)) { \
		test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64 " for key %u:%u", \
			  tst_name ?: "", (expected & e_cnt) ? "" : "not ",\
			  before->key_cnts[i].cnt,			\
			  after->key_cnts[i].cnt,			\
			  before->key_cnts[i].sndid,			\
			  before->key_cnts[i].rcvid);			\
		return -1;						\
	}								\
} while (0)

	if (before->nr_keys != after->nr_keys) {
		/* Guard against a NULL name, as the other messages do */
		test_fail("%s: Keys changed on the socket %zu != %zu",
			  tst_name ?: "", before->nr_keys, after->nr_keys);
		return -1;
	}

	/* per-key */
	i = before->nr_keys;
	while (i--) {
		if (sndid >= 0 && before->key_cnts[i].sndid != sndid)
			continue;
		if (rcvid >= 0 && before->key_cnts[i].rcvid != rcvid)
			continue;
		__cmp_ao(i, pkt_good, TEST_CNT_KEY_GOOD);
		__cmp_ao(i, pkt_bad, TEST_CNT_KEY_BAD);
	}
	return 0;
#undef __cmp_ao
}
582
test_tcp_counters_free(struct tcp_counters * cnts)583 void test_tcp_counters_free(struct tcp_counters *cnts)
584 {
585 free(cnts->ao.key_cnts);
586 }
587
588 #define TEST_BUF_SIZE 4096
_test_server_run(int sk,ssize_t quota,struct tcp_counters * c,test_cnt cond,volatile int * err,time_t timeout_sec)589 static ssize_t _test_server_run(int sk, ssize_t quota, struct tcp_counters *c,
590 test_cnt cond, volatile int *err,
591 time_t timeout_sec)
592 {
593 ssize_t total = 0;
594
595 do {
596 char buf[TEST_BUF_SIZE];
597 ssize_t bytes, sent;
598 int ret;
599
600 ret = __test_skpair_poll(sk, 0, timeout_sec, c, cond, err);
601 if (ret)
602 return ret;
603
604 bytes = recv(sk, buf, sizeof(buf), 0);
605
606 if (bytes < 0)
607 test_error("recv(): %zd", bytes);
608 if (bytes == 0)
609 break;
610
611 ret = __test_skpair_poll(sk, 1, timeout_sec, c, cond, err);
612 if (ret)
613 return ret;
614
615 sent = send(sk, buf, bytes, 0);
616 if (sent == 0)
617 break;
618 if (sent != bytes)
619 test_error("send()");
620 total += bytes;
621 } while (!quota || total < quota);
622
623 return total;
624 }
625
test_server_run(int sk,ssize_t quota,time_t timeout_sec)626 ssize_t test_server_run(int sk, ssize_t quota, time_t timeout_sec)
627 {
628 return _test_server_run(sk, quota, NULL, 0, NULL,
629 timeout_sec ?: TEST_TIMEOUT_SEC);
630 }
631
/*
 * Server side of a socket pair: reset *@err, snapshot the counters of @sk,
 * synchronize with the other threads and run the echo loop with
 * TEST_TIMEOUT_SEC, watching @cond and *@err.
 *
 * Returns the result of _test_server_run(), converted to int.
 */
int test_skpair_server(int sk, ssize_t quota, test_cnt cond, volatile int *err)
{
	struct tcp_counters baseline;
	ssize_t echoed;

	*err = 0;
	if (test_get_tcp_counters(sk, &baseline) != 0)
		test_error("test_get_tcp_counters()");
	synchronize_threads(); /* 1: init skpair & read nscounters */

	echoed = _test_server_run(sk, quota, &baseline, cond, err,
				  TEST_TIMEOUT_SEC);
	test_tcp_counters_free(&baseline);

	return echoed;
}
646
/*
 * Send @buf_sz bytes of random data over @sk in messages of up to @msg_len
 * bytes and verify that each message is echoed back unchanged.
 *
 * TCP_NODELAY is enabled on @sk. Every send and receive is preceded by
 * __test_skpair_poll() with TEST_TIMEOUT_SEC, @c, @cond and @err.
 *
 * Both the payload and the receive buffer live in one heap allocation
 * (rather than alloca()/a VLA), so large sizes cannot overflow the stack
 * and an allocation failure really is reported as -ENOMEM.
 *
 * Returns the number of bytes verified so far (equal to @buf_sz on full
 * success, less if send() returned 0 or recv() returned <= 0), the
 * non-zero result of __test_skpair_poll(), -1 if an echoed message
 * differs, or -ENOMEM. A short send() or an over-long echo is reported
 * through test_error().
 */
static ssize_t test_client_loop(int sk, size_t buf_sz, const size_t msg_len,
				struct tcp_counters *c, test_cnt cond,
				volatile int *err)
{
	ssize_t result;
	int nodelay = 1;
	char *buf, *msg;
	size_t i;

	/* One extra byte keeps the request non-zero for empty inputs */
	if (msg_len >= SIZE_MAX - buf_sz)
		return -ENOMEM;
	buf = malloc(buf_sz + msg_len + 1);
	if (!buf)
		return -ENOMEM;
	msg = buf + buf_sz;
	randomize_buffer(buf, buf_sz);

	if (setsockopt(sk, IPPROTO_TCP, TCP_NODELAY, &nodelay, sizeof(nodelay)))
		test_error("setsockopt(TCP_NODELAY)");

	for (i = 0; i < buf_sz; i += min(msg_len, buf_sz - i)) {
		size_t sent, bytes = min(msg_len, buf_sz - i);
		int ret;

		ret = __test_skpair_poll(sk, 1, TEST_TIMEOUT_SEC, c, cond, err);
		if (ret) {
			result = ret;
			goto out;
		}

		/* A send() error (-1) shows up below as sent != bytes */
		sent = send(sk, buf + i, bytes, 0);
		if (sent == 0)
			break;
		if (sent != bytes)
			test_error("send()");

		/* Collect the echo, which may arrive in several pieces */
		bytes = 0;
		do {
			ssize_t got;

			ret = __test_skpair_poll(sk, 0, TEST_TIMEOUT_SEC,
						 c, cond, err);
			if (ret) {
				result = ret;
				goto out;
			}

			got = recv(sk, msg + bytes, msg_len - bytes, 0);
			if (got <= 0) {
				result = i;
				goto out;
			}
			bytes += got;
		} while (bytes < sent);
		if (bytes > sent)
			test_error("recv(): %zu > %zu", bytes, sent);
		if (memcmp(buf + i, msg, bytes) != 0) {
			test_fail("received message differs");
			result = -1;
			goto out;
		}
	}
	result = i;
out:
	free(buf);
	return result;
}
701
/*
 * Send @nr messages of @msg_len bytes over @sk and verify the echoes,
 * without watching counters or a paired error.
 *
 * Returns 0 if all @msg_len * @nr bytes were verified, the negative result
 * of test_client_loop() on error, or -1 if fewer bytes were verified.
 */
int test_client_verify(int sk, const size_t msg_len, const size_t nr)
{
	const size_t total = msg_len * nr;
	ssize_t verified;

	verified = test_client_loop(sk, total, msg_len, NULL, 0, NULL);
	if (verified < 0)
		return (int)verified;
	if ((size_t)verified != total)
		return -1;

	return 0;
}
712
/*
 * Client side of a socket pair: reset *@err, snapshot the counters of @sk,
 * synchronize with the other threads, then send @nr messages of @msg_len
 * bytes and verify the echoes while watching @cond and *@err.
 *
 * Returns 0 if all @msg_len * @nr bytes were verified, the negative result
 * of test_client_loop() on error, or -1 if fewer bytes were verified.
 */
int test_skpair_client(int sk, const size_t msg_len, const size_t nr,
		       test_cnt cond, volatile int *err)
{
	const size_t total = msg_len * nr;
	struct tcp_counters baseline;
	ssize_t verified;

	*err = 0;
	if (test_get_tcp_counters(sk, &baseline) != 0)
		test_error("test_get_tcp_counters()");
	synchronize_threads(); /* 1: init skpair & read nscounters */

	verified = test_client_loop(sk, total, msg_len, &baseline, cond, err);
	test_tcp_counters_free(&baseline);

	if (verified < 0)
		return (int)verified;
	if ((size_t)verified != total)
		return -1;

	return 0;
}
731