1 // SPDX-License-Identifier: GPL-2.0
2 /* Author: Dmitry Safonov <dima@arista.com> */
3 #include <inttypes.h>
4 #include "aolib.h"
5
/*
 * fault(TYPE) evaluates to true when the fault_t variable named "inj" in the
 * calling scope equals FAULT_TYPE.  It can therefore only be used inside
 * functions that have such a variable (try_accept() and try_connect()).
 */
#define fault(type) (inj == FAULT_ ## type)
/*
 * Status word shared by the server and client threads.  Its address is passed
 * to test_skpair_wait_poll()/test_skpair_connect_poll(), and a thread whose
 * wait/connect fails stores the negative error here.  Presumably this lets the
 * peer's poll finish early instead of running into its own timeout.
 * NOTE(review): volatile is not a thread-synchronization primitive; the type
 * is kept because the helpers presumably take a volatile int * - confirm
 * before switching to an atomic type.
 */
static volatile int sk_pair;
8
test_add_key_maclen(int sk,const char * key,uint8_t maclen,union tcp_addr in_addr,uint8_t prefix,uint8_t sndid,uint8_t rcvid)9 static inline int test_add_key_maclen(int sk, const char *key, uint8_t maclen,
10 union tcp_addr in_addr, uint8_t prefix,
11 uint8_t sndid, uint8_t rcvid)
12 {
13 struct tcp_ao_add tmp = {};
14 int err;
15
16 if (prefix > DEFAULT_TEST_PREFIX)
17 prefix = DEFAULT_TEST_PREFIX;
18
19 err = test_prepare_key(&tmp, DEFAULT_TEST_ALGO, in_addr, false, false,
20 prefix, 0, sndid, rcvid, maclen,
21 0, strlen(key), key);
22 if (err)
23 return err;
24
25 err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp));
26 if (err < 0)
27 return -errno;
28
29 return test_verify_socket_key(sk, &tmp);
30 }
31
/*
 * Server half of one connect-deny test case.
 *
 * Listens on @port.  When @pwd is non-NULL, it installs a TCP-AO key for
 * @addr/@prefix with key ids @sndid/@rcvid and MAC length @maclen.  After the
 * "preparations done" barrier it waits on the listener (and on the shared
 * sk_pair status) and judges the outcome against the injected fault @inj:
 *   - -ETIMEDOUT from the wait is a failure unless @inj is FAULT_TIMEOUT;
 *   - -EKEYREJECTED is a failure unless @inj is FAULT_KEYREJECT;
 *   - a connection becoming ready/accepted is a failure with FAULT_TIMEOUT.
 * When @pwd is set, the listener's TCP-AO counters are compared against
 * @cnt_expected.  When @cnt_name is non-NULL, the netstat counter @cnt_name
 * must have increased.
 *
 * The three synchronize_threads() calls are paired with the three in
 * try_connect() running on the client thread; keep them in step.
 */
static void try_accept(const char *tst_name, unsigned int port, const char *pwd,
		       union tcp_addr addr, uint8_t prefix,
		       uint8_t sndid, uint8_t rcvid, uint8_t maclen,
		       const char *cnt_name, test_cnt cnt_expected,
		       fault_t inj)
{
	struct tcp_counters cnt1, cnt2;
	uint64_t before_cnt = 0, after_cnt = 0; /* silence GCC */
	/* Counter condition handed to the wait helper; 0 for the good path. */
	test_cnt poll_cnt = (cnt_expected == TEST_CNT_GOOD) ? 0 : cnt_expected;
	/*
	 * -1 means "nothing accepted": fd 0 is a valid descriptor that accept()
	 * may return, so it cannot double as the sentinel.
	 */
	int lsk, err, sk = -1;

	lsk = test_listen_socket(this_ip_addr, port, 1);

	if (pwd && test_add_key_maclen(lsk, pwd, maclen, addr, prefix, sndid, rcvid))
		test_error("setsockopt(TCP_AO_ADD_KEY)");

	if (cnt_name)
		before_cnt = netstat_get_one(cnt_name, NULL);
	if (pwd && test_get_tcp_counters(lsk, &cnt1))
		test_error("test_get_tcp_counters()");

	synchronize_threads(); /* preparations done */

	err = test_skpair_wait_poll(lsk, 0, poll_cnt, &sk_pair);
	if (err == -ETIMEDOUT) {
		/* Publish the timeout to the client thread. */
		sk_pair = err;
		if (!fault(TIMEOUT))
			test_fail("%s: timed out for accept()", tst_name);
	} else if (err == -EKEYREJECTED) {
		if (!fault(KEYREJECT))
			test_fail("%s: key was rejected", tst_name);
	} else if (err < 0) {
		test_error("test_skpair_wait_poll()");
	} else {
		if (fault(TIMEOUT))
			test_fail("%s: ready to accept", tst_name);

		sk = accept(lsk, NULL, NULL);
		if (sk < 0) {
			test_error("accept()");
		} else {
			if (fault(TIMEOUT))
				test_fail("%s: accepted", tst_name);
		}
	}

	synchronize_threads(); /* before counter checks */
	if (pwd && test_get_tcp_counters(lsk, &cnt2))
		test_error("test_get_tcp_counters()");

	close(lsk);

	if (pwd)
		test_assert_counters(tst_name, &cnt1, &cnt2, cnt_expected);

	if (!cnt_name)
		goto out;

	after_cnt = netstat_get_one(cnt_name, NULL);

	if (after_cnt <= before_cnt) {
		test_fail("%s: %s counter did not increase: %" PRIu64 " <= %" PRIu64,
			  tst_name, cnt_name, after_cnt, before_cnt);
	} else {
		test_ok("%s: counter %s increased %" PRIu64 " => %" PRIu64,
			tst_name, cnt_name, before_cnt, after_cnt);
	}

out:
	synchronize_threads(); /* close() */
	if (sk >= 0)
		close(sk);
}
105
/*
 * Server thread: runs the server half of every test case, one try_accept()
 * per case on consecutive ports starting at test_server_port.  client_fn()
 * walks the same port sequence, so the N-th try_accept() here is paired with
 * the N-th try_connect() there.  The two lists must stay in the same order.
 *
 * A prefix argument of -1 becomes 255 as uint8_t, which
 * test_add_key_maclen() caps at DEFAULT_TEST_PREFIX.
 */
static void *server_fn(void *arg)
{
	union tcp_addr wrong_addr, network_addr;
	unsigned int port = test_server_port;

	if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1)
		test_error("Can't convert ip address %s", TEST_WRONG_IP);

	/* Listener has no key while the client signs its SYN. */
	try_accept("Non-AO server + AO client", port++, NULL,
		   this_ip_dest, -1, 100, 100, 0,
		   "TCPAOKeyNotFound", TEST_CNT_NS_KEY_NOT_FOUND, FAULT_TIMEOUT);

	/* Listener has a key; the client connects without one. */
	try_accept("AO server + Non-AO client", port++, DEFAULT_TEST_PASSWORD,
		   this_ip_dest, -1, 100, 100, 0,
		   "TCPAORequired", TEST_CNT_AO_REQUIRED, FAULT_TIMEOUT);

	/* Same key ids, but the password differs from the client's. */
	try_accept("Wrong password", port++, "something that is not DEFAULT_TEST_PASSWORD",
		   this_ip_dest, -1, 100, 100, 0,
		   "TCPAOBad", TEST_CNT_BAD, FAULT_TIMEOUT);

	/* Server rcvid 101 does not match the client's key ids (100/100). */
	try_accept("Wrong rcv id", port++, DEFAULT_TEST_PASSWORD,
		   this_ip_dest, -1, 100, 101, 0,
		   "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND, FAULT_TIMEOUT);

	/*
	 * Server sndid 101 does not match the client's rcvid 100.  The test
	 * still expects TCPAOGood to increase here, while the connection
	 * must not complete.
	 */
	try_accept("Wrong snd id", port++, DEFAULT_TEST_PASSWORD,
		   this_ip_dest, -1, 101, 100, 0,
		   "TCPAOGood", TEST_CNT_GOOD, FAULT_TIMEOUT);

	/* Server key uses maclen 8; the client's key does not set one. */
	try_accept("Different maclen", port++, DEFAULT_TEST_PASSWORD,
		   this_ip_dest, -1, 100, 100, 8,
		   "TCPAOBad", TEST_CNT_BAD, FAULT_TIMEOUT);

	/* Server key is bound to TEST_WRONG_IP instead of the client address. */
	try_accept("Server: Wrong addr", port++, DEFAULT_TEST_PASSWORD,
		   wrong_addr, -1, 100, 100, 0,
		   "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND, FAULT_TIMEOUT);

	/*
	 * The client's own key is rejected, so the client reports the failure
	 * through sk_pair and this side ends its wait early.
	 */
	try_accept("Client: Wrong addr", port++, NULL,
		   this_ip_dest, -1, 100, 100, 0, NULL, 0, FAULT_KEYREJECT);

	/* Asymmetric ids that mirror the client's (100/200): must connect. */
	try_accept("rcv id != snd id", port++, DEFAULT_TEST_PASSWORD,
		   this_ip_dest, -1, 200, 100, 0,
		   "TCPAOGood", TEST_CNT_GOOD, 0);

	if (inet_pton(TEST_FAMILY, TEST_NETWORK, &network_addr) != 1)
		test_error("Can't convert ip address %s", TEST_NETWORK);

	/* Server key covers TEST_NETWORK/16 rather than a single address. */
	try_accept("Server: prefix match", port++, DEFAULT_TEST_PASSWORD,
		   network_addr, 16, 100, 100, 0,
		   "TCPAOGood", TEST_CNT_GOOD, 0);

	/* Here it is the client whose key uses the network prefix. */
	try_accept("Client: prefix match", port++, DEFAULT_TEST_PASSWORD,
		   this_ip_dest, -1, 100, 100, 0,
		   "TCPAOGood", TEST_CNT_GOOD, 0);

	/* client exits */
	synchronize_threads();
	return NULL;
}
165
/*
 * Client half of one connect-deny test case.
 *
 * Creates a TCP socket and, when @pwd is non-NULL, installs a TCP-AO key for
 * @addr/@prefix with key ids @sndid/@rcvid.  After the "preparations done"
 * barrier it connects to this_ip_dest:@port with test_skpair_connect_poll().
 *   - A failed connect (negative return) is published to the server thread
 *     through sk_pair.  It is accepted as expected only for -EKEYREJECTED
 *     with FAULT_KEYREJECT, -ETIMEDOUT with FAULT_TIMEOUT, or -ECONNREFUSED
 *     with either fault.  Anything else is a test error.
 *   - A successful connect is a failure if any fault was injected.  With a
 *     key, the TCP-AO counters are compared against @cnt_expected.
 *
 * The three synchronize_threads() calls are paired with those in
 * try_accept() on the server thread.
 */
static void try_connect(const char *tst_name, unsigned int port,
			const char *pwd, union tcp_addr addr, uint8_t prefix,
			uint8_t sndid, uint8_t rcvid,
			test_cnt cnt_expected, fault_t inj)
{
	struct tcp_counters cnt1, cnt2;
	int sk, ret;

	sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
	if (sk < 0)
		test_error("socket()");

	if (pwd && test_add_key(sk, pwd, addr, prefix, sndid, rcvid))
		test_error("setsockopt(TCP_AO_ADD_KEY)");

	if (pwd && test_get_tcp_counters(sk, &cnt1))
		test_error("test_get_tcp_counters()");

	synchronize_threads(); /* preparations done */

	ret = test_skpair_connect_poll(sk, this_ip_dest, port, cnt_expected, &sk_pair);
	synchronize_threads(); /* before counter checks */
	if (ret < 0) {
		sk_pair = ret;
		/*
		 * cnt1 is never handed to test_assert_counters() on this path,
		 * so release it here, as the ret == 0 branch below does.
		 */
		if (pwd)
			test_tcp_counters_free(&cnt1);
		if (fault(KEYREJECT) && ret == -EKEYREJECTED) {
			test_ok("%s: connect() was prevented", tst_name);
		} else if (ret == -ETIMEDOUT && fault(TIMEOUT)) {
			test_ok("%s", tst_name);
		} else if (ret == -ECONNREFUSED &&
			   (fault(TIMEOUT) || fault(KEYREJECT))) {
			test_ok("%s: refused to connect", tst_name);
		} else {
			test_error("%s: connect() returned %d", tst_name, ret);
		}
		goto out;
	}

	if (fault(TIMEOUT) || fault(KEYREJECT))
		test_fail("%s: connected", tst_name);
	else
		test_ok("%s: connected", tst_name);
	if (pwd && ret > 0) {
		if (test_get_tcp_counters(sk, &cnt2))
			test_error("test_get_tcp_counters()");
		/* NOTE(review): presumably consumes (frees) cnt1 and cnt2. */
		test_assert_counters(tst_name, &cnt1, &cnt2, cnt_expected);
	} else if (pwd) {
		test_tcp_counters_free(&cnt1);
	}
out:
	synchronize_threads(); /* close() */

	/*
	 * Only close on a positive return.  NOTE(review): this presumes the
	 * connect helper closes sk itself on failure - confirm in aolib.
	 */
	if (ret > 0)
		close(sk);
}
220
/*
 * Client thread: runs the client half of every test case, one try_connect()
 * per case on consecutive ports starting at test_server_port.  This mirrors
 * the sequence in server_fn(), so both lists must stay in the same order.
 *
 * Before most cases, a trace_*_expect() call registers the trace event that
 * case should produce.  It uses the current value of port, which the
 * following try_connect() then consumes with port++.
 */
static void *client_fn(void *arg)
{
	union tcp_addr wrong_addr, network_addr, addr_any = {};
	unsigned int port = test_server_port;

	if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1)
		test_error("Can't convert ip address %s", TEST_WRONG_IP);

	trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest,
			      -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
	try_connect("Non-AO server + AO client", port++, DEFAULT_TEST_PASSWORD,
		    this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);

	trace_hash_event_expect(TCP_HASH_AO_REQUIRED, this_ip_addr, this_ip_dest,
				-1, port, 0, 0, 1, 0, 0, 0);
	try_connect("AO server + Non-AO client", port++, NULL,
		    this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);

	trace_ao_event_expect(TCP_AO_MISMATCH, this_ip_addr, this_ip_dest,
			      -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
	try_connect("Wrong password", port++, DEFAULT_TEST_PASSWORD,
		    this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);

	trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest,
			      -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
	try_connect("Wrong rcv id", port++, DEFAULT_TEST_PASSWORD,
		    this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);

	/*
	 * XXX: The test doesn't increase any counters, see tcp_make_synack().
	 * Potentially, it could be sped up by setting sk_pair = -ETIMEDOUT,
	 * but the price would be increased complexity of the tracer thread.
	 */
	trace_ao_event_sk_expect(TCP_AO_SYNACK_NO_KEY, this_ip_dest, addr_any,
				 port, 0, 100, 100);
	try_connect("Wrong snd id", port++, DEFAULT_TEST_PASSWORD,
		    this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);

	trace_ao_event_expect(TCP_AO_WRONG_MACLEN, this_ip_addr, this_ip_dest,
			      -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
	try_connect("Different maclen", port++, DEFAULT_TEST_PASSWORD,
		    this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);

	trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest,
			      -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
	try_connect("Server: Wrong addr", port++, DEFAULT_TEST_PASSWORD,
		    this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);

	/* The client's key is for TEST_WRONG_IP; -EKEYREJECTED is expected. */
	try_connect("Client: Wrong addr", port++, DEFAULT_TEST_PASSWORD,
		    wrong_addr, -1, 100, 100, 0, FAULT_KEYREJECT);

	/* Key ids 100/200 mirror the server's 200/100: must connect. */
	try_connect("rcv id != snd id", port++, DEFAULT_TEST_PASSWORD,
		    this_ip_dest, -1, 100, 200, TEST_CNT_GOOD, 0);

	if (inet_pton(TEST_FAMILY, TEST_NETWORK, &network_addr) != 1)
		test_error("Can't convert ip address %s", TEST_NETWORK);

	/* The server side holds the network-prefix key for this case. */
	try_connect("Server: prefix match", port++, DEFAULT_TEST_PASSWORD,
		    this_ip_dest, -1, 100, 100, TEST_CNT_GOOD, 0);

	/* The client's key covers TEST_NETWORK/16 rather than one address. */
	try_connect("Client: prefix match", port++, DEFAULT_TEST_PASSWORD,
		    network_addr, 16, 100, 100, TEST_CNT_GOOD, 0);

	return NULL;
}
286
main(int argc,char * argv[])287 int main(int argc, char *argv[])
288 {
289 test_init(22, server_fn, client_fn);
290 return 0;
291 }
292