1 // SPDX-License-Identifier: GPL-2.0 2 /* Author: Dmitry Safonov <dima@arista.com> */ 3 #include <inttypes.h> 4 #include "aolib.h" 5 6 #define fault(type) (inj == FAULT_ ## type) 7 static volatile int sk_pair; 8 9 static inline int test_add_key_maclen(int sk, const char *key, uint8_t maclen, 10 union tcp_addr in_addr, uint8_t prefix, 11 uint8_t sndid, uint8_t rcvid) 12 { 13 struct tcp_ao_add tmp = {}; 14 int err; 15 16 if (prefix > DEFAULT_TEST_PREFIX) 17 prefix = DEFAULT_TEST_PREFIX; 18 19 err = test_prepare_key(&tmp, DEFAULT_TEST_ALGO, in_addr, false, false, 20 prefix, 0, sndid, rcvid, maclen, 21 0, strlen(key), key); 22 if (err) 23 return err; 24 25 err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp)); 26 if (err < 0) 27 return -errno; 28 29 return test_verify_socket_key(sk, &tmp); 30 } 31 32 static void try_accept(const char *tst_name, unsigned int port, const char *pwd, 33 union tcp_addr addr, uint8_t prefix, 34 uint8_t sndid, uint8_t rcvid, uint8_t maclen, 35 const char *cnt_name, test_cnt cnt_expected, 36 fault_t inj) 37 { 38 struct tcp_counters cnt1, cnt2; 39 uint64_t before_cnt = 0, after_cnt = 0; /* silence GCC */ 40 test_cnt poll_cnt = (cnt_expected == TEST_CNT_GOOD) ? 0 : cnt_expected; 41 int lsk, err, sk = 0; 42 43 lsk = test_listen_socket(this_ip_addr, port, 1); 44 45 if (pwd && test_add_key_maclen(lsk, pwd, maclen, addr, prefix, sndid, rcvid)) 46 test_error("setsockopt(TCP_AO_ADD_KEY)"); 47 48 if (cnt_name) 49 before_cnt = netstat_get_one(cnt_name, NULL); 50 if (pwd && test_get_tcp_counters(lsk, &cnt1)) 51 test_error("test_get_tcp_counters()"); 52 53 synchronize_threads(); /* preparations done */ 54 55 err = test_skpair_wait_poll(lsk, 0, poll_cnt, &sk_pair); 56 if (err == -ETIMEDOUT) { 57 sk_pair = err; 58 if (!fault(TIMEOUT)) 59 test_fail("%s: timed out for accept()", tst_name); 60 } else if (err == -EKEYREJECTED) { 61 if (!fault(KEYREJECT)) 62 test_fail("%s: key was rejected", tst_name); 63 } else if (err < 0) { 64 test_error("test_skpair_wait_poll()"); 65 } else { 66 if (fault(TIMEOUT)) 67 test_fail("%s: ready to accept", tst_name); 68 69 sk = accept(lsk, NULL, NULL); 70 if (sk < 0) { 71 test_error("accept()"); 72 } else { 73 if (fault(TIMEOUT)) 74 test_fail("%s: accepted", tst_name); 75 } 76 } 77 78 synchronize_threads(); /* before counter checks */ 79 if (pwd && test_get_tcp_counters(lsk, &cnt2)) 80 test_error("test_get_tcp_counters()"); 81 82 close(lsk); 83 84 if (pwd) 85 test_assert_counters(tst_name, &cnt1, &cnt2, cnt_expected); 86 87 if (!cnt_name) 88 goto out; 89 90 after_cnt = netstat_get_one(cnt_name, NULL); 91 92 if (after_cnt <= before_cnt) { 93 test_fail("%s: %s counter did not increase: %" PRIu64 " <= %" PRIu64, 94 tst_name, cnt_name, after_cnt, before_cnt); 95 } else { 96 test_ok("%s: counter %s increased %" PRIu64 " => %" PRIu64, 97 tst_name, cnt_name, before_cnt, after_cnt); 98 } 99 100 out: 101 synchronize_threads(); /* close() */ 102 if (sk > 0) 103 close(sk); 104 } 105 106 static void *server_fn(void *arg) 107 { 108 union tcp_addr wrong_addr, network_addr; 109 unsigned int port = test_server_port; 110 111 if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1) 112 test_error("Can't convert ip address %s", TEST_WRONG_IP); 113 114 try_accept("Non-AO server + AO client", port++, NULL, 115 this_ip_dest, -1, 100, 100, 0, 116 "TCPAOKeyNotFound", TEST_CNT_NS_KEY_NOT_FOUND, FAULT_TIMEOUT); 117 118 try_accept("AO server + Non-AO client", port++, DEFAULT_TEST_PASSWORD, 119 this_ip_dest, -1, 100, 100, 0, 120 "TCPAORequired", TEST_CNT_AO_REQUIRED, FAULT_TIMEOUT); 121 122 try_accept("Wrong password", port++, "something that is not DEFAULT_TEST_PASSWORD", 123 this_ip_dest, -1, 100, 100, 0, 124 "TCPAOBad", TEST_CNT_BAD, FAULT_TIMEOUT); 125 126 try_accept("Wrong rcv id", port++, DEFAULT_TEST_PASSWORD, 127 this_ip_dest, -1, 100, 101, 0, 128 "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND, FAULT_TIMEOUT); 129 130 try_accept("Wrong snd id", port++, DEFAULT_TEST_PASSWORD, 131 this_ip_dest, -1, 101, 100, 0, 132 "TCPAOGood", TEST_CNT_GOOD, FAULT_TIMEOUT); 133 134 try_accept("Different maclen", port++, DEFAULT_TEST_PASSWORD, 135 this_ip_dest, -1, 100, 100, 8, 136 "TCPAOBad", TEST_CNT_BAD, FAULT_TIMEOUT); 137 138 try_accept("Server: Wrong addr", port++, DEFAULT_TEST_PASSWORD, 139 wrong_addr, -1, 100, 100, 0, 140 "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND, FAULT_TIMEOUT); 141 142 /* Key rejected by the other side, failing short through skpair */ 143 try_accept("Client: Wrong addr", port++, NULL, 144 this_ip_dest, -1, 100, 100, 0, NULL, 0, FAULT_KEYREJECT); 145 146 try_accept("rcv id != snd id", port++, DEFAULT_TEST_PASSWORD, 147 this_ip_dest, -1, 200, 100, 0, 148 "TCPAOGood", TEST_CNT_GOOD, 0); 149 150 if (inet_pton(TEST_FAMILY, TEST_NETWORK, &network_addr) != 1) 151 test_error("Can't convert ip address %s", TEST_NETWORK); 152 153 try_accept("Server: prefix match", port++, DEFAULT_TEST_PASSWORD, 154 network_addr, 16, 100, 100, 0, 155 "TCPAOGood", TEST_CNT_GOOD, 0); 156 157 try_accept("Client: prefix match", port++, DEFAULT_TEST_PASSWORD, 158 this_ip_dest, -1, 100, 100, 0, 159 "TCPAOGood", TEST_CNT_GOOD, 0); 160 161 /* client exits */ 162 synchronize_threads(); 163 return NULL; 164 } 165 166 static void try_connect(const char *tst_name, unsigned int port, 167 const char *pwd, union tcp_addr addr, uint8_t prefix, 168 uint8_t sndid, uint8_t rcvid, 169 test_cnt cnt_expected, fault_t inj) 170 { 171 struct tcp_counters cnt1, cnt2; 172 int sk, ret; 173 174 sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); 175 if (sk < 0) 176 test_error("socket()"); 177 178 if (pwd && test_add_key(sk, pwd, addr, prefix, sndid, rcvid)) 179 test_error("setsockopt(TCP_AO_ADD_KEY)"); 180 181 if (pwd && test_get_tcp_counters(sk, &cnt1)) 182 test_error("test_get_tcp_counters()"); 183 184 synchronize_threads(); /* preparations done */ 185 186 ret = test_skpair_connect_poll(sk, this_ip_dest, port, cnt_expected, &sk_pair); 187 synchronize_threads(); /* before counter checks */ 188 if (ret < 0) { 189 sk_pair = ret; 190 if (fault(KEYREJECT) && ret == -EKEYREJECTED) { 191 test_ok("%s: connect() was prevented", tst_name); 192 } else if (ret == -ETIMEDOUT && fault(TIMEOUT)) { 193 test_ok("%s", tst_name); 194 } else if (ret == -ECONNREFUSED && 195 (fault(TIMEOUT) || fault(KEYREJECT))) { 196 test_ok("%s: refused to connect", tst_name); 197 } else { 198 test_error("%s: connect() returned %d", tst_name, ret); 199 } 200 goto out; 201 } 202 203 if (fault(TIMEOUT) || fault(KEYREJECT)) 204 test_fail("%s: connected", tst_name); 205 else 206 test_ok("%s: connected", tst_name); 207 if (pwd && ret > 0) { 208 if (test_get_tcp_counters(sk, &cnt2)) 209 test_error("test_get_tcp_counters()"); 210 test_assert_counters(tst_name, &cnt1, &cnt2, cnt_expected); 211 } else if (pwd) { 212 test_tcp_counters_free(&cnt1); 213 } 214 out: 215 synchronize_threads(); /* close() */ 216 217 if (ret > 0) 218 close(sk); 219 } 220 221 static void *client_fn(void *arg) 222 { 223 union tcp_addr wrong_addr, network_addr, addr_any = {}; 224 unsigned int port = test_server_port; 225 226 if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1) 227 test_error("Can't convert ip address %s", TEST_WRONG_IP); 228 229 trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest, 230 -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1); 231 try_connect("Non-AO server + AO client", port++, DEFAULT_TEST_PASSWORD, 232 this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT); 233 234 trace_hash_event_expect(TCP_HASH_AO_REQUIRED, this_ip_addr, this_ip_dest, 235 -1, port, 0, 0, 1, 0, 0, 0); 236 try_connect("AO server + Non-AO client", port++, NULL, 237 this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT); 238 239 trace_ao_event_expect(TCP_AO_MISMATCH, this_ip_addr, this_ip_dest, 240 -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1); 241 try_connect("Wrong password", port++, DEFAULT_TEST_PASSWORD, 242 this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT); 243 244 trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest, 245 -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1); 246 try_connect("Wrong rcv id", port++, DEFAULT_TEST_PASSWORD, 247 this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT); 248 249 /* 250 * XXX: The test doesn't increase any counters, see tcp_make_synack(). 251 * Potentially, it can be speed up by setting sk_pair = -ETIMEDOUT 252 * but the price would be increased complexity of the tracer thread. 253 */ 254 trace_ao_event_sk_expect(TCP_AO_SYNACK_NO_KEY, this_ip_dest, addr_any, 255 port, 0, 100, 100); 256 try_connect("Wrong snd id", port++, DEFAULT_TEST_PASSWORD, 257 this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT); 258 259 trace_ao_event_expect(TCP_AO_WRONG_MACLEN, this_ip_addr, this_ip_dest, 260 -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1); 261 try_connect("Different maclen", port++, DEFAULT_TEST_PASSWORD, 262 this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT); 263 264 trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest, 265 -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1); 266 try_connect("Server: Wrong addr", port++, DEFAULT_TEST_PASSWORD, 267 this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT); 268 269 try_connect("Client: Wrong addr", port++, DEFAULT_TEST_PASSWORD, 270 wrong_addr, -1, 100, 100, 0, FAULT_KEYREJECT); 271 272 try_connect("rcv id != snd id", port++, DEFAULT_TEST_PASSWORD, 273 this_ip_dest, -1, 100, 200, TEST_CNT_GOOD, 0); 274 275 if (inet_pton(TEST_FAMILY, TEST_NETWORK, &network_addr) != 1) 276 test_error("Can't convert ip address %s", TEST_NETWORK); 277 278 try_connect("Server: prefix match", port++, DEFAULT_TEST_PASSWORD, 279 this_ip_dest, -1, 100, 100, TEST_CNT_GOOD, 0); 280 281 try_connect("Client: prefix match", port++, DEFAULT_TEST_PASSWORD, 282 network_addr, 16, 100, 100, TEST_CNT_GOOD, 0); 283 284 return NULL; 285 } 286 287 int main(int argc, char *argv[]) 288 { 289 test_init(22, server_fn, client_fn); 290 return 0; 291 } 292