xref: /linux/tools/testing/selftests/net/tcp_ao/connect-deny.c (revision 1a9239bb4253f9076b5b4b2a1a4e8d7defd77a95)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Author: Dmitry Safonov <dima@arista.com> */
3 #include <inttypes.h>
4 #include "aolib.h"
5 
/*
 * fault(TYPE) is true when the enclosing function's fault_t parameter `inj`
 * equals FAULT_TYPE. It relies on a variable literally named `inj` being in
 * scope at the point of use.
 */
#define fault(type)	((inj) == FAULT_##type)

/*
 * Status word handed (by address) to the aolib skpair helpers by both the
 * server and client threads. The thread that observes a failure stores the
 * negative error here.
 * NOTE(review): volatile is not a synchronization primitive; kept because the
 * helpers' pointer parameter type is not visible here — confirm before changing.
 */
static volatile int sk_pair;
8 
test_add_key_maclen(int sk,const char * key,uint8_t maclen,union tcp_addr in_addr,uint8_t prefix,uint8_t sndid,uint8_t rcvid)9 static inline int test_add_key_maclen(int sk, const char *key, uint8_t maclen,
10 				      union tcp_addr in_addr, uint8_t prefix,
11 				      uint8_t sndid, uint8_t rcvid)
12 {
13 	struct tcp_ao_add tmp = {};
14 	int err;
15 
16 	if (prefix > DEFAULT_TEST_PREFIX)
17 		prefix = DEFAULT_TEST_PREFIX;
18 
19 	err = test_prepare_key(&tmp, DEFAULT_TEST_ALGO, in_addr, false, false,
20 			       prefix, 0, sndid, rcvid, maclen,
21 			       0, strlen(key), key);
22 	if (err)
23 		return err;
24 
25 	err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp));
26 	if (err < 0)
27 		return -errno;
28 
29 	return test_verify_socket_key(sk, &tmp);
30 }
31 
/*
 * Server half of one sub-test, run in lock-step with try_connect().
 *
 * Listens on this_ip_addr:@port and, if @pwd is set, installs a TCP-AO key
 * (@addr/@prefix/@sndid/@rcvid/@maclen) on the listener. It then waits for a
 * connection via test_skpair_wait_poll() and checks the outcome against @inj:
 *   FAULT_TIMEOUT   - no connection must become acceptable;
 *   FAULT_KEYREJECT - the wait is expected to end with -EKEYREJECTED;
 *   otherwise       - a connection must be accepted.
 *
 * If @pwd is set, the listener's AO counters are compared against
 * @cnt_expected. If @cnt_name is set, that netstat counter must have grown.
 *
 * The three synchronize_threads() calls (preparations, counter checks,
 * close) must pair with those in try_connect().
 */
static void try_accept(const char *tst_name, unsigned int port, const char *pwd,
		       union tcp_addr addr, uint8_t prefix,
		       uint8_t sndid, uint8_t rcvid, uint8_t maclen,
		       const char *cnt_name, test_cnt cnt_expected,
		       fault_t inj)
{
	struct tcp_counters cnt1, cnt2;
	uint64_t before_cnt = 0, after_cnt = 0; /* silence GCC */
	test_cnt poll_cnt = (cnt_expected == TEST_CNT_GOOD) ? 0 : cnt_expected;
	/* -1 means "nothing accepted": 0 is a valid fd and must not be the sentinel. */
	int lsk, err, sk = -1;

	lsk = test_listen_socket(this_ip_addr, port, 1);

	if (pwd && test_add_key_maclen(lsk, pwd, maclen, addr, prefix, sndid, rcvid))
		test_error("setsockopt(TCP_AO_ADD_KEY)");

	if (cnt_name)
		before_cnt = netstat_get_one(cnt_name, NULL);
	if (pwd && test_get_tcp_counters(lsk, &cnt1))
		test_error("test_get_tcp_counters()");

	synchronize_threads(); /* preparations done */

	err = test_skpair_wait_poll(lsk, 0, poll_cnt, &sk_pair);
	if (err == -ETIMEDOUT) {
		/* Publish the timeout through sk_pair for the peer thread. */
		sk_pair = err;
		if (!fault(TIMEOUT))
			test_fail("%s: timed out for accept()", tst_name);
	} else if (err == -EKEYREJECTED) {
		if (!fault(KEYREJECT))
			test_fail("%s: key was rejected", tst_name);
	} else if (err < 0) {
		test_error("test_skpair_wait_poll()");
	} else {
		if (fault(TIMEOUT))
			test_fail("%s: ready to accept", tst_name);

		sk = accept(lsk, NULL, NULL);
		if (sk < 0) {
			test_error("accept()");
		} else {
			if (fault(TIMEOUT))
				test_fail("%s: accepted", tst_name);
		}
	}

	synchronize_threads(); /* before counter checks */
	if (pwd && test_get_tcp_counters(lsk, &cnt2))
		test_error("test_get_tcp_counters()");

	close(lsk);

	if (pwd)
		test_assert_counters(tst_name, &cnt1, &cnt2, cnt_expected);

	if (!cnt_name)
		goto out;

	after_cnt = netstat_get_one(cnt_name, NULL);

	if (after_cnt <= before_cnt) {
		test_fail("%s: %s counter did not increase: %" PRIu64 " <= %" PRIu64,
				tst_name, cnt_name, after_cnt, before_cnt);
	} else {
		test_ok("%s: counter %s increased %" PRIu64  " => %" PRIu64,
			tst_name, cnt_name, before_cnt, after_cnt);
	}

out:
	synchronize_threads(); /* close() */
	if (sk >= 0)
		close(sk);
}
105 
server_fn(void * arg)106 static void *server_fn(void *arg)
107 {
108 	union tcp_addr wrong_addr, network_addr;
109 	unsigned int port = test_server_port;
110 
111 	if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1)
112 		test_error("Can't convert ip address %s", TEST_WRONG_IP);
113 
114 	try_accept("Non-AO server + AO client", port++, NULL,
115 		   this_ip_dest, -1, 100, 100, 0,
116 		   "TCPAOKeyNotFound", TEST_CNT_NS_KEY_NOT_FOUND, FAULT_TIMEOUT);
117 
118 	try_accept("AO server + Non-AO client", port++, DEFAULT_TEST_PASSWORD,
119 		   this_ip_dest, -1, 100, 100, 0,
120 		   "TCPAORequired", TEST_CNT_AO_REQUIRED, FAULT_TIMEOUT);
121 
122 	try_accept("Wrong password", port++, "something that is not DEFAULT_TEST_PASSWORD",
123 		   this_ip_dest, -1, 100, 100, 0,
124 		   "TCPAOBad", TEST_CNT_BAD, FAULT_TIMEOUT);
125 
126 	try_accept("Wrong rcv id", port++, DEFAULT_TEST_PASSWORD,
127 		   this_ip_dest, -1, 100, 101, 0,
128 		   "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND, FAULT_TIMEOUT);
129 
130 	try_accept("Wrong snd id", port++, DEFAULT_TEST_PASSWORD,
131 		   this_ip_dest, -1, 101, 100, 0,
132 		   "TCPAOGood", TEST_CNT_GOOD, FAULT_TIMEOUT);
133 
134 	try_accept("Different maclen", port++, DEFAULT_TEST_PASSWORD,
135 		   this_ip_dest, -1, 100, 100, 8,
136 		   "TCPAOBad", TEST_CNT_BAD, FAULT_TIMEOUT);
137 
138 	try_accept("Server: Wrong addr", port++, DEFAULT_TEST_PASSWORD,
139 		   wrong_addr, -1, 100, 100, 0,
140 		   "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND, FAULT_TIMEOUT);
141 
142 	/* Key rejected by the other side, failing short through skpair */
143 	try_accept("Client: Wrong addr", port++, NULL,
144 		   this_ip_dest, -1, 100, 100, 0, NULL, 0, FAULT_KEYREJECT);
145 
146 	try_accept("rcv id != snd id", port++, DEFAULT_TEST_PASSWORD,
147 		   this_ip_dest, -1, 200, 100, 0,
148 		   "TCPAOGood", TEST_CNT_GOOD, 0);
149 
150 	if (inet_pton(TEST_FAMILY, TEST_NETWORK, &network_addr) != 1)
151 		test_error("Can't convert ip address %s", TEST_NETWORK);
152 
153 	try_accept("Server: prefix match", port++, DEFAULT_TEST_PASSWORD,
154 		   network_addr, 16, 100, 100, 0,
155 		   "TCPAOGood", TEST_CNT_GOOD, 0);
156 
157 	try_accept("Client: prefix match", port++, DEFAULT_TEST_PASSWORD,
158 		   this_ip_dest, -1, 100, 100, 0,
159 		   "TCPAOGood", TEST_CNT_GOOD, 0);
160 
161 	/* client exits */
162 	synchronize_threads();
163 	return NULL;
164 }
165 
/*
 * Client half of one sub-test, run in lock-step with try_accept().
 *
 * Creates a TCP socket and, if @pwd is set, installs a TCP-AO key
 * (@addr/@prefix/@sndid/@rcvid). It then connects to this_ip_dest:@port with
 * test_skpair_connect_poll() and checks the result against @inj:
 *   FAULT_KEYREJECT - -EKEYREJECTED or -ECONNREFUSED is a pass;
 *   FAULT_TIMEOUT   - -ETIMEDOUT or -ECONNREFUSED is a pass;
 *   otherwise       - the connect must succeed.
 * Any other negative result is reported with test_error().
 *
 * On a successful connect with @pwd set, the socket's AO counters are
 * compared against @cnt_expected.
 *
 * The three synchronize_threads() calls (preparations, counter checks,
 * close) must pair with those in try_accept().
 */
static void try_connect(const char *tst_name, unsigned int port,
			const char *pwd, union tcp_addr addr, uint8_t prefix,
			uint8_t sndid, uint8_t rcvid,
			test_cnt cnt_expected, fault_t inj)
{
	struct tcp_counters cnt1, cnt2;
	int sk, ret;

	sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
	if (sk < 0)
		test_error("socket()");

	if (pwd && test_add_key(sk, pwd, addr, prefix, sndid, rcvid))
		test_error("setsockopt(TCP_AO_ADD_KEY)");

	if (pwd && test_get_tcp_counters(sk, &cnt1))
		test_error("test_get_tcp_counters()");

	synchronize_threads(); /* preparations done */

	ret = test_skpair_connect_poll(sk, this_ip_dest, port, cnt_expected, &sk_pair);
	synchronize_threads(); /* before counter checks */
	if (ret < 0) {
		/* Publish the failure through sk_pair for the peer thread. */
		sk_pair = ret;
		if (fault(KEYREJECT) && ret == -EKEYREJECTED) {
			test_ok("%s: connect() was prevented", tst_name);
		} else if (ret == -ETIMEDOUT && fault(TIMEOUT)) {
			test_ok("%s", tst_name);
		} else if (ret == -ECONNREFUSED &&
				(fault(TIMEOUT) || fault(KEYREJECT))) {
			test_ok("%s: refused to connect", tst_name);
		} else {
			test_error("%s: connect() returned %d", tst_name, ret);
		}
		/*
		 * The counter snapshot taken above is not compared on this
		 * path. Release it, as the no-compare branch below does.
		 */
		if (pwd)
			test_tcp_counters_free(&cnt1);
		goto out;
	}

	if (fault(TIMEOUT) || fault(KEYREJECT))
		test_fail("%s: connected", tst_name);
	else
		test_ok("%s: connected", tst_name);
	if (pwd && ret > 0) {
		if (test_get_tcp_counters(sk, &cnt2))
			test_error("test_get_tcp_counters()");
		test_assert_counters(tst_name, &cnt1, &cnt2, cnt_expected);
	} else if (pwd) {
		test_tcp_counters_free(&cnt1);
	}
out:
	synchronize_threads(); /* close() */

	/*
	 * sk is closed here only when the connect helper reported success
	 * (ret > 0). NOTE(review): this presumes the helper disposes of sk on
	 * failure; confirm against aolib before changing.
	 */
	if (ret > 0)
		close(sk);
}
220 
client_fn(void * arg)221 static void *client_fn(void *arg)
222 {
223 	union tcp_addr wrong_addr, network_addr, addr_any = {};
224 	unsigned int port = test_server_port;
225 
226 	if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1)
227 		test_error("Can't convert ip address %s", TEST_WRONG_IP);
228 
229 	trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest,
230 			      -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
231 	try_connect("Non-AO server + AO client", port++, DEFAULT_TEST_PASSWORD,
232 			this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
233 
234 	trace_hash_event_expect(TCP_HASH_AO_REQUIRED, this_ip_addr, this_ip_dest,
235 				-1, port, 0, 0, 1, 0, 0, 0);
236 	try_connect("AO server + Non-AO client", port++, NULL,
237 			this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
238 
239 	trace_ao_event_expect(TCP_AO_MISMATCH, this_ip_addr, this_ip_dest,
240 			      -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
241 	try_connect("Wrong password", port++, DEFAULT_TEST_PASSWORD,
242 			this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
243 
244 	trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest,
245 			      -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
246 	try_connect("Wrong rcv id", port++, DEFAULT_TEST_PASSWORD,
247 			this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
248 
249 	/*
250 	 * XXX: The test doesn't increase any counters, see tcp_make_synack().
251 	 * Potentially, it can be speed up by setting sk_pair = -ETIMEDOUT
252 	 * but the price would be increased complexity of the tracer thread.
253 	 */
254 	trace_ao_event_sk_expect(TCP_AO_SYNACK_NO_KEY, this_ip_dest, addr_any,
255 				 port, 0, 100, 100);
256 	try_connect("Wrong snd id", port++, DEFAULT_TEST_PASSWORD,
257 			this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
258 
259 	trace_ao_event_expect(TCP_AO_WRONG_MACLEN, this_ip_addr, this_ip_dest,
260 			      -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
261 	try_connect("Different maclen", port++, DEFAULT_TEST_PASSWORD,
262 			this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
263 
264 	trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest,
265 			      -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
266 	try_connect("Server: Wrong addr", port++, DEFAULT_TEST_PASSWORD,
267 			this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
268 
269 	try_connect("Client: Wrong addr", port++, DEFAULT_TEST_PASSWORD,
270 			wrong_addr, -1, 100, 100, 0, FAULT_KEYREJECT);
271 
272 	try_connect("rcv id != snd id", port++, DEFAULT_TEST_PASSWORD,
273 			this_ip_dest, -1, 100, 200, TEST_CNT_GOOD, 0);
274 
275 	if (inet_pton(TEST_FAMILY, TEST_NETWORK, &network_addr) != 1)
276 		test_error("Can't convert ip address %s", TEST_NETWORK);
277 
278 	try_connect("Server: prefix match", port++, DEFAULT_TEST_PASSWORD,
279 			this_ip_dest, -1, 100, 100, TEST_CNT_GOOD, 0);
280 
281 	try_connect("Client: prefix match", port++, DEFAULT_TEST_PASSWORD,
282 			network_addr, 16, 100, 100, TEST_CNT_GOOD, 0);
283 
284 	return NULL;
285 }
286 
main(int argc,char * argv[])287 int main(int argc, char *argv[])
288 {
289 	test_init(22, server_fn, client_fn);
290 	return 0;
291 }
292