1 // SPDX-License-Identifier: GPL-2.0
2 #include <alloca.h>
3 #include <fcntl.h>
4 #include <inttypes.h>
5 #include <string.h>
6 #include "../../../../../include/linux/kernel.h"
7 #include "../../../../../include/linux/stringify.h"
8 #include "aolib.h"
9
/* Port number exported to the tests for the server side of a connection. */
const unsigned int test_server_port = 7010;
__test_listen_socket(int backlog,void * addr,size_t addr_sz)11 int __test_listen_socket(int backlog, void *addr, size_t addr_sz)
12 {
13 int err, sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
14 long flags;
15
16 if (sk < 0)
17 test_error("socket()");
18
19 err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, veth_name,
20 strlen(veth_name) + 1);
21 if (err < 0)
22 test_error("setsockopt(SO_BINDTODEVICE)");
23
24 if (bind(sk, (struct sockaddr *)addr, addr_sz) < 0)
25 test_error("bind()");
26
27 flags = fcntl(sk, F_GETFL);
28 if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0))
29 test_error("fcntl()");
30
31 if (listen(sk, backlog))
32 test_error("listen()");
33
34 return sk;
35 }
36
/*
 * Wait with select() until @sk is readable, or writable when @write is true.
 *
 * @sec is the timeout in seconds; zero means no timeout is passed to
 * select().
 *
 * Returns 0 when the descriptor is ready and SO_ERROR reports no error,
 * -ETIMEDOUT (errno is set to ETIMEDOUT too) when the timeout expired,
 * a negated errno when select() or getsockopt() failed, or the negated
 * SO_ERROR value found on the socket.
 */
int test_wait_fd(int sk, time_t sec, bool write)
{
	struct timeval timeout = { .tv_sec = sec };
	fd_set ready, except;
	int so_err = 0;
	socklen_t len;
	int nready;

	FD_ZERO(&ready);
	FD_ZERO(&except);
	FD_SET(sk, &ready);
	FD_SET(sk, &except);

	errno = 0;
	/* The same set serves as either the read set or the write set. */
	nready = select(sk + 1,
			write ? NULL : &ready,
			write ? &ready : NULL,
			&except,
			sec ? &timeout : NULL);
	if (nready < 0)
		return -errno;
	if (nready == 0) {
		errno = ETIMEDOUT;
		return -ETIMEDOUT;
	}

	len = sizeof(so_err);
	if (getsockopt(sk, SOL_SOCKET, SO_ERROR, &so_err, &len) != 0)
		return -errno;

	return so_err ? -so_err : 0;
}
71
__test_connect_socket(int sk,const char * device,void * addr,size_t addr_sz,time_t timeout)72 int __test_connect_socket(int sk, const char *device,
73 void *addr, size_t addr_sz, time_t timeout)
74 {
75 long flags;
76 int err;
77
78 if (device != NULL) {
79 err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, device,
80 strlen(device) + 1);
81 if (err < 0)
82 test_error("setsockopt(SO_BINDTODEVICE, %s)", device);
83 }
84
85 if (!timeout) {
86 err = connect(sk, addr, addr_sz);
87 if (err) {
88 err = -errno;
89 goto out;
90 }
91 return 0;
92 }
93
94 flags = fcntl(sk, F_GETFL);
95 if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0))
96 test_error("fcntl()");
97
98 if (connect(sk, addr, addr_sz) < 0) {
99 if (errno != EINPROGRESS) {
100 err = -errno;
101 goto out;
102 }
103 if (timeout < 0)
104 return sk;
105 err = test_wait_fd(sk, timeout, 1);
106 if (err)
107 goto out;
108 }
109 return sk;
110
111 out:
112 close(sk);
113 return err;
114 }
115
__test_set_md5(int sk,void * addr,size_t addr_sz,uint8_t prefix,int vrf,const char * password)116 int __test_set_md5(int sk, void *addr, size_t addr_sz, uint8_t prefix,
117 int vrf, const char *password)
118 {
119 size_t pwd_len = strlen(password);
120 struct tcp_md5sig md5sig = {};
121
122 md5sig.tcpm_keylen = pwd_len;
123 memcpy(md5sig.tcpm_key, password, pwd_len);
124 md5sig.tcpm_flags = TCP_MD5SIG_FLAG_PREFIX;
125 md5sig.tcpm_prefixlen = prefix;
126 if (vrf >= 0) {
127 md5sig.tcpm_flags |= TCP_MD5SIG_FLAG_IFINDEX;
128 md5sig.tcpm_ifindex = (uint8_t)vrf;
129 }
130 memcpy(&md5sig.tcpm_addr, addr, addr_sz);
131
132 errno = 0;
133 return setsockopt(sk, IPPROTO_TCP, TCP_MD5SIG_EXT,
134 &md5sig, sizeof(md5sig));
135 }
136
137
test_prepare_key_sockaddr(struct tcp_ao_add * ao,const char * alg,void * addr,size_t addr_sz,bool set_current,bool set_rnext,uint8_t prefix,uint8_t vrf,uint8_t sndid,uint8_t rcvid,uint8_t maclen,uint8_t keyflags,uint8_t keylen,const char * key)138 int test_prepare_key_sockaddr(struct tcp_ao_add *ao, const char *alg,
139 void *addr, size_t addr_sz, bool set_current, bool set_rnext,
140 uint8_t prefix, uint8_t vrf, uint8_t sndid, uint8_t rcvid,
141 uint8_t maclen, uint8_t keyflags,
142 uint8_t keylen, const char *key)
143 {
144 memset(ao, 0, sizeof(struct tcp_ao_add));
145
146 ao->set_current = !!set_current;
147 ao->set_rnext = !!set_rnext;
148 ao->prefix = prefix;
149 ao->sndid = sndid;
150 ao->rcvid = rcvid;
151 ao->maclen = maclen;
152 ao->keyflags = keyflags;
153 ao->keylen = keylen;
154 ao->ifindex = vrf;
155
156 memcpy(&ao->addr, addr, addr_sz);
157
158 if (strlen(alg) > 64)
159 return -ENOBUFS;
160 strncpy(ao->alg_name, alg, 64);
161
162 memcpy(ao->key, key,
163 (keylen > TCP_AO_MAXKEYLEN) ? TCP_AO_MAXKEYLEN : keylen);
164 return 0;
165 }
166
test_get_ao_keys_nr(int sk)167 static int test_get_ao_keys_nr(int sk)
168 {
169 struct tcp_ao_getsockopt tmp = {};
170 socklen_t tmp_sz = sizeof(tmp);
171 int ret;
172
173 tmp.nkeys = 1;
174 tmp.get_all = 1;
175
176 ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz);
177 if (ret)
178 return -errno;
179 return (int)tmp.nkeys;
180 }
181
test_get_one_ao(int sk,struct tcp_ao_getsockopt * out,void * addr,size_t addr_sz,uint8_t prefix,uint8_t sndid,uint8_t rcvid)182 int test_get_one_ao(int sk, struct tcp_ao_getsockopt *out,
183 void *addr, size_t addr_sz, uint8_t prefix,
184 uint8_t sndid, uint8_t rcvid)
185 {
186 struct tcp_ao_getsockopt tmp = {};
187 socklen_t tmp_sz = sizeof(tmp);
188 int ret;
189
190 memcpy(&tmp.addr, addr, addr_sz);
191 tmp.prefix = prefix;
192 tmp.sndid = sndid;
193 tmp.rcvid = rcvid;
194 tmp.nkeys = 1;
195
196 ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz);
197 if (ret)
198 return ret;
199 if (tmp.nkeys != 1)
200 return -E2BIG;
201 *out = tmp;
202 return 0;
203 }
204
test_get_ao_info(int sk,struct tcp_ao_info_opt * out)205 int test_get_ao_info(int sk, struct tcp_ao_info_opt *out)
206 {
207 socklen_t sz = sizeof(*out);
208
209 out->reserved = 0;
210 out->reserved2 = 0;
211 if (getsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, out, &sz))
212 return -errno;
213 if (sz != sizeof(*out))
214 return -EMSGSIZE;
215 return 0;
216 }
217
test_set_ao_info(int sk,struct tcp_ao_info_opt * in)218 int test_set_ao_info(int sk, struct tcp_ao_info_opt *in)
219 {
220 socklen_t sz = sizeof(*in);
221
222 in->reserved = 0;
223 in->reserved2 = 0;
224 if (setsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, in, sz))
225 return -errno;
226 return 0;
227 }
228
test_cmp_getsockopt_setsockopt(const struct tcp_ao_add * a,const struct tcp_ao_getsockopt * b)229 int test_cmp_getsockopt_setsockopt(const struct tcp_ao_add *a,
230 const struct tcp_ao_getsockopt *b)
231 {
232 bool is_kdf_aes_128_cmac = false;
233 bool is_cmac_aes = false;
234
235 if (!strcmp("cmac(aes128)", a->alg_name)) {
236 is_kdf_aes_128_cmac = (a->keylen != 16);
237 is_cmac_aes = true;
238 }
239
240 #define __cmp_ao(member) \
241 do { \
242 if (b->member != a->member) { \
243 test_fail("getsockopt(): " __stringify(member) " %u != %u", \
244 b->member, a->member); \
245 return -1; \
246 } \
247 } while(0)
248 __cmp_ao(sndid);
249 __cmp_ao(rcvid);
250 __cmp_ao(prefix);
251 __cmp_ao(keyflags);
252 __cmp_ao(ifindex);
253 if (a->maclen) {
254 __cmp_ao(maclen);
255 } else if (b->maclen != 12) {
256 test_fail("getsockopt(): expected default maclen 12, but it's %u",
257 b->maclen);
258 return -1;
259 }
260 if (!is_kdf_aes_128_cmac) {
261 __cmp_ao(keylen);
262 } else if (b->keylen != 16) {
263 test_fail("getsockopt(): expected keylen 16 for cmac(aes128), but it's %u",
264 b->keylen);
265 return -1;
266 }
267 #undef __cmp_ao
268 if (!is_kdf_aes_128_cmac && memcmp(b->key, a->key, a->keylen)) {
269 test_fail("getsockopt(): returned key is different `%s' != `%s'",
270 b->key, a->key);
271 return -1;
272 }
273 if (memcmp(&b->addr, &a->addr, sizeof(b->addr))) {
274 test_fail("getsockopt(): returned address is different");
275 return -1;
276 }
277 if (!is_cmac_aes && strcmp(b->alg_name, a->alg_name)) {
278 test_fail("getsockopt(): returned algorithm %s is different than %s", b->alg_name, a->alg_name);
279 return -1;
280 }
281 if (is_cmac_aes && strcmp(b->alg_name, "cmac(aes)")) {
282 test_fail("getsockopt(): returned algorithm %s is different than cmac(aes)", b->alg_name);
283 return -1;
284 }
285 /* For a established key rotation test don't add a key with
286 * set_current = 1, as it's likely to change by peer's request;
287 * rather use setsockopt(TCP_AO_INFO)
288 */
289 if (a->set_current != b->is_current) {
290 test_fail("getsockopt(): returned key is not Current_key");
291 return -1;
292 }
293 if (a->set_rnext != b->is_rnext) {
294 test_fail("getsockopt(): returned key is not RNext_key");
295 return -1;
296 }
297
298 return 0;
299 }
300
test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt * a,const struct tcp_ao_info_opt * b)301 int test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt *a,
302 const struct tcp_ao_info_opt *b)
303 {
304 /* No check for ::current_key, as it may change by the peer */
305 if (a->ao_required != b->ao_required) {
306 test_fail("getsockopt(): returned ao doesn't have ao_required");
307 return -1;
308 }
309 if (a->accept_icmps != b->accept_icmps) {
310 test_fail("getsockopt(): returned ao doesn't accept ICMPs");
311 return -1;
312 }
313 if (a->set_rnext && a->rnext != b->rnext) {
314 test_fail("getsockopt(): RNext KeyID has changed");
315 return -1;
316 }
317 #define __cmp_cnt(member) \
318 do { \
319 if (b->member != a->member) { \
320 test_fail("getsockopt(): " __stringify(member) " %llu != %llu", \
321 b->member, a->member); \
322 return -1; \
323 } \
324 } while(0)
325 if (a->set_counters) {
326 __cmp_cnt(pkt_good);
327 __cmp_cnt(pkt_bad);
328 __cmp_cnt(pkt_key_not_found);
329 __cmp_cnt(pkt_ao_required);
330 __cmp_cnt(pkt_dropped_icmp);
331 }
332 #undef __cmp_cnt
333 return 0;
334 }
335
test_get_tcp_ao_counters(int sk,struct tcp_ao_counters * out)336 int test_get_tcp_ao_counters(int sk, struct tcp_ao_counters *out)
337 {
338 struct tcp_ao_getsockopt *key_dump;
339 socklen_t key_dump_sz = sizeof(*key_dump);
340 struct tcp_ao_info_opt info = {};
341 bool c1, c2, c3, c4, c5;
342 struct netstat *ns;
343 int err, nr_keys;
344
345 memset(out, 0, sizeof(*out));
346
347 /* per-netns */
348 ns = netstat_read();
349 out->netns_ao_good = netstat_get(ns, "TCPAOGood", &c1);
350 out->netns_ao_bad = netstat_get(ns, "TCPAOBad", &c2);
351 out->netns_ao_key_not_found = netstat_get(ns, "TCPAOKeyNotFound", &c3);
352 out->netns_ao_required = netstat_get(ns, "TCPAORequired", &c4);
353 out->netns_ao_dropped_icmp = netstat_get(ns, "TCPAODroppedIcmps", &c5);
354 netstat_free(ns);
355 if (c1 || c2 || c3 || c4 || c5)
356 return -EOPNOTSUPP;
357
358 err = test_get_ao_info(sk, &info);
359 if (err)
360 return err;
361
362 /* per-socket */
363 out->ao_info_pkt_good = info.pkt_good;
364 out->ao_info_pkt_bad = info.pkt_bad;
365 out->ao_info_pkt_key_not_found = info.pkt_key_not_found;
366 out->ao_info_pkt_ao_required = info.pkt_ao_required;
367 out->ao_info_pkt_dropped_icmp = info.pkt_dropped_icmp;
368
369 /* per-key */
370 nr_keys = test_get_ao_keys_nr(sk);
371 if (nr_keys < 0)
372 return nr_keys;
373 if (nr_keys == 0)
374 test_error("test_get_ao_keys_nr() == 0");
375 out->nr_keys = (size_t)nr_keys;
376 key_dump = calloc(nr_keys, key_dump_sz);
377 if (!key_dump)
378 return -errno;
379
380 key_dump[0].nkeys = nr_keys;
381 key_dump[0].get_all = 1;
382 err = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS,
383 key_dump, &key_dump_sz);
384 if (err) {
385 free(key_dump);
386 return -errno;
387 }
388
389 out->key_cnts = calloc(nr_keys, sizeof(out->key_cnts[0]));
390 if (!out->key_cnts) {
391 free(key_dump);
392 return -errno;
393 }
394
395 while (nr_keys--) {
396 out->key_cnts[nr_keys].sndid = key_dump[nr_keys].sndid;
397 out->key_cnts[nr_keys].rcvid = key_dump[nr_keys].rcvid;
398 out->key_cnts[nr_keys].pkt_good = key_dump[nr_keys].pkt_good;
399 out->key_cnts[nr_keys].pkt_bad = key_dump[nr_keys].pkt_bad;
400 }
401 free(key_dump);
402
403 return 0;
404 }
405
/*
 * Compare the per-netns and per-socket TCP-AO counters of two snapshots.
 *
 * For every counter two things are verified: it did not decrease between
 * @before and @after, and it changed if and only if the corresponding
 * TEST_CNT_* bit is set in @expected.
 *
 * errno is reset to 0 on entry. The first violation is reported with
 * test_fail(), prefixed by @tst_name (empty if NULL), and -1 is returned;
 * 0 is returned when all counters behave as expected.
 */
int __test_tcp_ao_counters_cmp(const char *tst_name,
			       struct tcp_ao_counters *before,
			       struct tcp_ao_counters *after,
			       test_cnt expected)
{
#define CHECK_CNT(cnt, flag)						\
do {									\
	const bool want_inc = !!(expected & (flag));			\
									\
	if (before->cnt > after->cnt) {					\
		test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64, \
			  tst_name ?: "", before->cnt, after->cnt);	\
		return -1;						\
	}								\
	if ((before->cnt != after->cnt) != want_inc) {			\
		test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64, \
			  tst_name ?: "", want_inc ? "" : "not ",	\
			  before->cnt, after->cnt);			\
		return -1;						\
	}								\
} while (0)

	errno = 0;

	/* per-netns */
	CHECK_CNT(netns_ao_good, TEST_CNT_NS_GOOD);
	CHECK_CNT(netns_ao_bad, TEST_CNT_NS_BAD);
	CHECK_CNT(netns_ao_key_not_found, TEST_CNT_NS_KEY_NOT_FOUND);
	CHECK_CNT(netns_ao_required, TEST_CNT_NS_AO_REQUIRED);
	CHECK_CNT(netns_ao_dropped_icmp, TEST_CNT_NS_DROPPED_ICMP);

	/* per-socket */
	CHECK_CNT(ao_info_pkt_good, TEST_CNT_SOCK_GOOD);
	CHECK_CNT(ao_info_pkt_bad, TEST_CNT_SOCK_BAD);
	CHECK_CNT(ao_info_pkt_key_not_found, TEST_CNT_SOCK_KEY_NOT_FOUND);
	CHECK_CNT(ao_info_pkt_ao_required, TEST_CNT_SOCK_AO_REQUIRED);
	CHECK_CNT(ao_info_pkt_dropped_icmp, TEST_CNT_SOCK_DROPPED_ICMP);
#undef CHECK_CNT

	return 0;
}
446
/*
 * Compare the per-key TCP-AO counters of two snapshots.
 *
 * @sndid/@rcvid select the keys to check: a negative value matches any id,
 * otherwise only keys of @before with that id are examined. For every
 * selected key, pkt_good and pkt_bad must not decrease and must change if
 * and only if TEST_CNT_KEY_GOOD / TEST_CNT_KEY_BAD is set in @expected.
 *
 * Both snapshots must contain the same number of keys; entries are matched
 * by array position.
 *
 * The first violation is reported with test_fail(), prefixed by @tst_name
 * (empty if NULL), and -1 is returned; 0 is returned otherwise.
 */
int test_tcp_ao_key_counters_cmp(const char *tst_name,
				 struct tcp_ao_counters *before,
				 struct tcp_ao_counters *after,
				 test_cnt expected,
				 int sndid, int rcvid)
{
	size_t i;
#define CHECK_KEY_CNT(i, cnt, expecting_inc)				\
do {									\
	if (before->key_cnts[i].cnt > after->key_cnts[i].cnt) {		\
		test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64 " for key %u:%u", \
			  tst_name ?: "", before->key_cnts[i].cnt,	\
			  after->key_cnts[i].cnt,			\
			  before->key_cnts[i].sndid,			\
			  before->key_cnts[i].rcvid);			\
		return -1;						\
	}								\
	if ((before->key_cnts[i].cnt != after->key_cnts[i].cnt) != (expecting_inc)) { \
		test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64 " for key %u:%u", \
			  tst_name ?: "", (expecting_inc) ? "" : "not ", \
			  before->key_cnts[i].cnt,			\
			  after->key_cnts[i].cnt,			\
			  before->key_cnts[i].sndid,			\
			  before->key_cnts[i].rcvid);			\
		return -1;						\
	}								\
} while (0)

	if (before->nr_keys != after->nr_keys) {
		/* tst_name may be NULL: never hand NULL to "%s". */
		test_fail("%s: Keys changed on the socket %zu != %zu",
			  tst_name ?: "", before->nr_keys, after->nr_keys);
		return -1;
	}

	/* per-key */
	i = before->nr_keys;
	while (i--) {
		if (sndid >= 0 && before->key_cnts[i].sndid != sndid)
			continue;
		if (rcvid >= 0 && before->key_cnts[i].rcvid != rcvid)
			continue;
		CHECK_KEY_CNT(i, pkt_good, !!(expected & TEST_CNT_KEY_GOOD));
		CHECK_KEY_CNT(i, pkt_bad, !!(expected & TEST_CNT_KEY_BAD));
	}
#undef CHECK_KEY_CNT

	return 0;
}
494
test_tcp_ao_counters_free(struct tcp_ao_counters * cnts)495 void test_tcp_ao_counters_free(struct tcp_ao_counters *cnts)
496 {
497 free(cnts->key_cnts);
498 }
499
/* Size of the on-stack buffer used by test_server_run() for each recv(). */
#define TEST_BUF_SIZE 4096
/*
 * Echo server loop: receive data on @sk and send it straight back.
 *
 * Each recv() and send() is preceded by test_wait_fd() with @timeout_sec.
 * The loop ends when the peer closes the connection (recv() returns 0),
 * when send() returns 0, or, if @quota is non-zero, once at least @quota
 * bytes have been echoed.
 *
 * Returns the number of bytes echoed, or the negative value returned by
 * test_wait_fd(). A failing recv() or a short send() is reported through
 * test_error().
 */
ssize_t test_server_run(int sk, ssize_t quota, time_t timeout_sec)
{
	ssize_t echoed = 0;

	for (;;) {
		char buf[TEST_BUF_SIZE];
		ssize_t rcvd, sent;
		int err;

		err = test_wait_fd(sk, timeout_sec, 0);
		if (err)
			return err;

		rcvd = recv(sk, buf, sizeof(buf), 0);
		if (rcvd < 0)
			test_error("recv(): %zd", rcvd);
		if (rcvd == 0)
			break;

		err = test_wait_fd(sk, timeout_sec, 1);
		if (err)
			return err;

		sent = send(sk, buf, rcvd, 0);
		if (sent == 0)
			break;
		if (sent != rcvd)
			test_error("send()");

		echoed += rcvd;
		/* A zero quota means: run until the connection ends. */
		if (quota && echoed >= quota)
			break;
	}

	return echoed;
}
535
test_client_loop(int sk,char * buf,size_t buf_sz,const size_t msg_len,time_t timeout_sec)536 ssize_t test_client_loop(int sk, char *buf, size_t buf_sz,
537 const size_t msg_len, time_t timeout_sec)
538 {
539 char msg[msg_len];
540 int nodelay = 1;
541 size_t i;
542
543 if (setsockopt(sk, IPPROTO_TCP, TCP_NODELAY, &nodelay, sizeof(nodelay)))
544 test_error("setsockopt(TCP_NODELAY)");
545
546 for (i = 0; i < buf_sz; i += min(msg_len, buf_sz - i)) {
547 size_t sent, bytes = min(msg_len, buf_sz - i);
548 int ret;
549
550 ret = test_wait_fd(sk, timeout_sec, 1);
551 if (ret)
552 return ret;
553
554 sent = send(sk, buf + i, bytes, 0);
555 if (sent == 0)
556 break;
557 if (sent != bytes)
558 test_error("send()");
559
560 bytes = 0;
561 do {
562 ssize_t got;
563
564 ret = test_wait_fd(sk, timeout_sec, 0);
565 if (ret)
566 return ret;
567
568 got = recv(sk, msg + bytes, sizeof(msg) - bytes, 0);
569 if (got <= 0)
570 return i;
571 bytes += got;
572 } while (bytes < sent);
573 if (bytes > sent)
574 test_error("recv(): %zd > %zd", bytes, sent);
575 if (memcmp(buf + i, msg, bytes) != 0) {
576 test_fail("received message differs");
577 return -1;
578 }
579 }
580 return i;
581 }
582
/*
 * Send @nr messages of @msg_len random bytes each over @sk and verify that
 * the peer echoes all of them back unchanged (see test_client_loop()).
 *
 * Returns 0 on success, the negative value returned by test_client_loop(),
 * -1 if fewer bytes than requested were exchanged, -EOVERFLOW if
 * @msg_len * @nr does not fit into size_t, or -ENOMEM if the payload buffer
 * cannot be allocated.
 */
int test_client_verify(int sk, const size_t msg_len, const size_t nr,
		       time_t timeout_sec)
{
	int saved_errno;
	size_t buf_sz;
	ssize_t ret;
	char *buf;

	if (nr && msg_len > SIZE_MAX / nr)
		return -EOVERFLOW;
	buf_sz = msg_len * nr;

	/* Heap rather than alloca(): the size is caller-controlled. */
	buf = malloc(buf_sz ? buf_sz : 1);
	if (!buf)
		return -ENOMEM;

	randomize_buffer(buf, buf_sz);
	ret = test_client_loop(sk, buf, buf_sz, msg_len, timeout_sec);

	/* Keep the errno left by the loop visible to the caller. */
	saved_errno = errno;
	free(buf);
	errno = saved_errno;

	if (ret < 0)
		return (int)ret;
	return (size_t)ret == buf_sz ? 0 : -1;
}
596