xref: /linux/tools/testing/selftests/net/tcp_ao/rst.c (revision 9e56ff53b4115875667760445b028357848b4748)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Author: Dmitry Safonov <dima@arista.com> */
3 #include <inttypes.h>
4 #include "../../../../include/linux/kernel.h"
5 #include "aolib.h"
6 
/*
 * Amount of data used by each established-connection check: handed to
 * test_server_run() on the server side and, as quota / 100 rounds of 100
 * (presumably bytes per round; see test_client_verify()), on the client.
 * The passive-RST test also advances the restored socket's outgoing
 * sequence number by this amount.
 */
const size_t quota = 1000;
/*
 * Backlog == 0 means 1 connection in queue, see:
 * commit 64a146513f8f ("[NET]: Revert incorrect accept queue...")
 *
 * Deliberately left without an initializer: a file-scope object is
 * zero-initialized, so the active-RST listener keeps one queued connection.
 */
const unsigned int backlog;
13 
14 static void netstats_check(struct netstat *before, struct netstat *after,
15 			   char *msg)
16 {
17 	uint64_t before_cnt, after_cnt;
18 
19 	before_cnt = netstat_get(before, "TCPAORequired", NULL);
20 	after_cnt = netstat_get(after, "TCPAORequired", NULL);
21 	if (after_cnt > before_cnt)
22 		test_fail("Segments without AO sign (%s): %" PRIu64 " => %" PRIu64,
23 			  msg, before_cnt, after_cnt);
24 	else
25 		test_ok("No segments without AO sign (%s)", msg);
26 
27 	before_cnt = netstat_get(before, "TCPAOGood", NULL);
28 	after_cnt = netstat_get(after, "TCPAOGood", NULL);
29 	if (after_cnt <= before_cnt)
30 		test_fail("Signed AO segments (%s): %" PRIu64 " => %" PRIu64,
31 			  msg, before_cnt, after_cnt);
32 	else
33 		test_ok("Signed AO segments (%s): %" PRIu64 " => %" PRIu64,
34 			  msg, before_cnt, after_cnt);
35 
36 	before_cnt = netstat_get(before, "TCPAOBad", NULL);
37 	after_cnt = netstat_get(after, "TCPAOBad", NULL);
38 	if (after_cnt > before_cnt)
39 		test_fail("Segments with bad AO sign (%s): %" PRIu64 " => %" PRIu64,
40 			  msg, before_cnt, after_cnt);
41 	else
42 		test_ok("No segments with bad AO sign (%s)", msg);
43 }
44 
/*
 * Close @sk abortively: SO_LINGER enabled with a zero timeout makes close()
 * reset the connection instead of performing an orderly shutdown.
 *
 * This exercises the other RST path besides tcp_v{4,6}_send_reset():
 * tcp_send_active_reset() sends the RST actively rather than in reply to an
 * inbound segment. It goes through tcp_transmit_skb(), so AO signing should
 * already work there, but covering it as well is cheap.
 */
static void close_forced(int sk)
{
	const struct linger abort_on_close = {
		.l_onoff = 1,
		.l_linger = 0,
	};

	if (setsockopt(sk, SOL_SOCKET, SO_LINGER,
		       &abort_on_close, sizeof(abort_on_close)) != 0)
		test_error("setsockopt(SO_LINGER)");
	close(sk);
}
61 
/*
 * Wait for an exceptional condition on @sk using select().
 * @sec: timeout in seconds; 0 means wait without a timeout.
 *
 * Returns @sk if the condition was signalled, 0 on timeout, or -errno if
 * select() failed. (A ready socket with descriptor 0 is indistinguishable
 * from a timeout.)
 */
static int test_wait_for_exception(int sk, time_t sec)
{
	struct timeval timeout = { .tv_sec = sec };
	fd_set except_set;
	int nready;

	FD_ZERO(&except_set);
	FD_SET(sk, &except_set);

	errno = 0;
	nready = select(sk + 1, NULL, NULL, &except_set,
			sec ? &timeout : NULL);
	if (nready < 0)
		return -errno;
	if (nready == 0)
		return 0;
	return sk;
}
81 
/*
 * Server side of the "active RST" test.
 *
 * Listens on @port with a @backlog-sized accept queue and an AO key, then
 * accept()s one connection while the client leaves another one queued.
 * Next it closes the listener, serves @quota bytes on the accepted socket,
 * and finally closes that socket with close_forced() so that an RST is sent
 * actively. The AO counters of the listening socket are snapshotted before
 * and after accept(). They must show only good packets.
 *
 * Every synchronize_threads() call pairs with the same-numbered phase in
 * test_client_active_rst(), so statement order here is significant.
 */
static void test_server_active_rst(unsigned int port)
{
	struct tcp_ao_counters cnt1, cnt2;
	ssize_t bytes;
	int sk, lsk;

	lsk = test_listen_socket(this_ip_addr, port, backlog);
	if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
		test_error("setsockopt(TCP_AO_ADD_KEY)");
	if (test_get_tcp_ao_counters(lsk, &cnt1))
		test_error("test_get_tcp_ao_counters()");

	synchronize_threads(); /* 1: MKT added */
	if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
		test_error("test_wait_fd()");

	sk = accept(lsk, NULL, NULL);
	if (sk < 0)
		test_error("accept()");

	synchronize_threads(); /* 2: connection accept()ed, another queued */
	if (test_get_tcp_ao_counters(lsk, &cnt2))
		test_error("test_get_tcp_ao_counters()");

	synchronize_threads(); /* 3: close listen socket */
	/*
	 * Presumably this resets the connection still sitting in the accept
	 * queue; the client later expects ECONNRESET on its sockets.
	 */
	close(lsk);
	bytes = test_server_run(sk, quota, 0);
	if (bytes != quota)
		test_error("servered only %zd bytes", bytes);
	else
		test_ok("servered %zd bytes", bytes);

	synchronize_threads(); /* 4: finishing up */
	/* Abortive close: the accepted connection gets an active RST. */
	close_forced(sk);

	/* Phase 5: the client checks that its sockets were reset. */
	synchronize_threads(); /* 5: closed active sk */

	synchronize_threads(); /* 6: counters checks */
	if (test_tcp_ao_counters_cmp("active RST server", &cnt1, &cnt2, TEST_CNT_GOOD))
		test_fail("MKT counters (server) have not only good packets");
	else
		test_ok("MKT counters are good on server");
}
125 
126 static void test_server_passive_rst(unsigned int port)
127 {
128 	struct tcp_ao_counters ao1, ao2;
129 	int sk, lsk;
130 	ssize_t bytes;
131 
132 	lsk = test_listen_socket(this_ip_addr, port, 1);
133 
134 	if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
135 		test_error("setsockopt(TCP_AO_ADD_KEY)");
136 
137 	synchronize_threads(); /* 1: MKT added => connect() */
138 	if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
139 		test_error("test_wait_fd()");
140 
141 	sk = accept(lsk, NULL, NULL);
142 	if (sk < 0)
143 		test_error("accept()");
144 
145 	synchronize_threads(); /* 2: accepted => send data */
146 	close(lsk);
147 	if (test_get_tcp_ao_counters(sk, &ao1))
148 		test_error("test_get_tcp_ao_counters()");
149 
150 	bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC);
151 	if (bytes != quota) {
152 		if (bytes > 0)
153 			test_fail("server served: %zd", bytes);
154 		else
155 			test_fail("server returned %zd", bytes);
156 	}
157 
158 	synchronize_threads(); /* 3: chekpoint/restore the connection */
159 	if (test_get_tcp_ao_counters(sk, &ao2))
160 		test_error("test_get_tcp_ao_counters()");
161 
162 	synchronize_threads(); /* 4: terminate server + send more on client */
163 	bytes = test_server_run(sk, quota, TEST_RETRANSMIT_SEC);
164 	close(sk);
165 	test_tcp_ao_counters_cmp("passive RST server", &ao1, &ao2, TEST_CNT_GOOD);
166 
167 	synchronize_threads(); /* 5: verified => closed */
168 	close(sk);
169 }
170 
171 static void *server_fn(void *arg)
172 {
173 	struct netstat *ns_before, *ns_after;
174 	unsigned int port = test_server_port;
175 
176 	ns_before = netstat_read();
177 
178 	test_server_active_rst(port++);
179 	test_server_passive_rst(port++);
180 
181 	ns_after = netstat_read();
182 	netstats_check(ns_before, ns_after, "server");
183 	netstat_free(ns_after);
184 	netstat_free(ns_before);
185 	synchronize_threads(); /* exit */
186 
187 	synchronize_threads(); /* don't race to exit() - client exits */
188 	return NULL;
189 }
190 
/*
 * Wait until @wait_for of the @nr descriptors in @sk have become writable
 * or have reported an exceptional condition. A descriptor is dropped from
 * the wait set once it has been seen ready.
 *
 * @is_writable: optional (may be NULL). Entries 0..@nr-1 are cleared first.
 *               Each is then set to true if that descriptor became writable.
 * @sec:         timeout in seconds, or 0 to wait without a timeout. The same
 *               struct timeval is reused for every select() call. Linux
 *               select() decrements it, so @sec bounds the whole wait there.
 *
 * Returns 0 once @wait_for descriptors were ready. Returns -ENOENT if no
 * descriptors are left to wait on (including @nr == 0), -ETIMEDOUT on
 * timeout, and -errno if select() fails.
 */
static int test_wait_fds(int sk[], size_t nr, bool is_writable[],
			 ssize_t wait_for, time_t sec)
{
	struct timeval timeout = { .tv_sec = sec };
	struct timeval *timeout_p = sec ? &timeout : NULL;
	fd_set pending;
	size_t idx;

	FD_ZERO(&pending);
	for (idx = 0; idx < nr; idx++) {
		FD_SET(sk[idx], &pending);
		if (is_writable)
			is_writable[idx] = false;
	}

	for (;;) {
		fd_set wr_set, ex_set;
		bool any_pending = false;
		int max_fd = 0;
		int nready;

		FD_ZERO(&wr_set);
		FD_ZERO(&ex_set);
		for (idx = 0; idx < nr; idx++) {
			int fd = sk[idx];

			if (FD_ISSET(fd, &pending)) {
				any_pending = true;
				max_fd = fd > max_fd ? fd : max_fd;
				FD_SET(fd, &wr_set);
				FD_SET(fd, &ex_set);
			}
		}
		if (!any_pending)
			return -ENOENT;

		errno = 0;
		nready = select(max_fd + 1, NULL, &wr_set, &ex_set, timeout_p);
		if (nready < 0)
			return -errno;
		if (nready == 0)
			return -ETIMEDOUT;

		for (idx = 0; idx < nr; idx++) {
			int fd = sk[idx];
			bool writable = FD_ISSET(fd, &wr_set);

			if (!writable && !FD_ISSET(fd, &ex_set))
				continue;
			if (writable && is_writable)
				is_writable[idx] = true;
			FD_CLR(fd, &pending);
			wait_for--;
		}

		if (wait_for <= 0)
			return 0;
	}
}
254 
/*
 * Client side of the "active RST" test.
 *
 * Creates three AO-keyed sockets. The first two connect while the server's
 * listener is open: one gets accept()ed and one stays queued. Data exchange
 * on the accepted connection (sk[0]) is then verified. After the server has
 * closed its listener, the third socket connects to the same port.
 * Finally, every socket that was writable after the first wait (that is, it
 * had connected) must report SO_ERROR == ECONNRESET. The others are only
 * logged.
 *
 * Every synchronize_threads() call pairs with the same-numbered phase in
 * test_server_active_rst(), so statement order here is significant.
 */
static void test_client_active_rst(unsigned int port)
{
	/* one in queue, another accept()ed */
	unsigned int wait_for = backlog + 2;
	int i, sk[3], err;
	bool is_writable[ARRAY_SIZE(sk)] = {false};
	unsigned int last = ARRAY_SIZE(sk) - 1;

	for (i = 0; i < ARRAY_SIZE(sk); i++) {
		sk[i] = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
		if (sk[i] < 0)
			test_error("socket()");
		if (test_add_key(sk[i], DEFAULT_TEST_PASSWORD,
				 this_ip_dest, -1, 100, 100))
			test_error("setsockopt(TCP_AO_ADD_KEY)");
	}

	synchronize_threads(); /* 1: MKT added */
	/*
	 * Connect all but the last socket. Only the first waits up to
	 * TEST_TIMEOUT_SEC; -1 presumably means "don't wait for completion"
	 * (see _test_connect_socket()).
	 */
	for (i = 0; i < last; i++) {
		err = _test_connect_socket(sk[i], this_ip_dest, port,
					       (i == 0) ? TEST_TIMEOUT_SEC : -1);

		if (err < 0)
			test_error("failed to connect()");
	}

	synchronize_threads(); /* 2: connection accept()ed, another queued */
	/* Record which of the connecting sockets became writable. */
	err = test_wait_fds(sk, last, is_writable, wait_for, TEST_TIMEOUT_SEC);
	if (err < 0)
		test_error("test_wait_fds(): %d", err);

	synchronize_threads(); /* 3: close listen socket */
	if (test_client_verify(sk[0], 100, quota / 100, TEST_TIMEOUT_SEC))
		test_fail("Failed to send data on connected socket");
	else
		test_ok("Verified established tcp connection");

	synchronize_threads(); /* 4: finishing up */
	/* The server's listener has been closed by now. */
	err = _test_connect_socket(sk[last], this_ip_dest, port, -1);
	if (err < 0)
		test_error("failed to connect()");

	synchronize_threads(); /* 5: closed active sk */
	/*
	 * is_writable is deliberately not passed here: the check below uses
	 * the writability recorded in phase 2.
	 */
	err = test_wait_fds(sk, ARRAY_SIZE(sk), NULL,
			    wait_for, TEST_TIMEOUT_SEC);
	if (err < 0)
		test_error("select(): %d", err);

	for (i = 0; i < ARRAY_SIZE(sk); i++) {
		socklen_t slen = sizeof(err);

		if (getsockopt(sk[i], SOL_SOCKET, SO_ERROR, &err, &slen))
			test_error("getsockopt()");
		if (is_writable[i] && err != ECONNRESET) {
			test_fail("sk[%d] = %d, err = %d, connection wasn't reset",
				  i, sk[i], err);
		} else {
			test_ok("sk[%d] = %d%s", i, sk[i],
				is_writable[i] ? ", connection was reset" : "");
		}
	}
	synchronize_threads(); /* 6: counters checks */
}
318 
/*
 * Client side of the "passive RST" test.
 *
 * Connects with an AO key and verifies data exchange. It then uses TCP
 * repair to checkpoint the socket and its AO state, and discards the socket
 * with test_kill_sk(). The state is restored into a fresh socket with the
 * outgoing sequence number advanced by @quota. After that, sending is
 * expected to fail: the check passes when test_client_verify() reports an
 * error. The socket must end up with SO_ERROR ECONNRESET or EPIPE,
 * presumably because the server answers the desynchronized segments with an
 * RST. AO counters on the restored socket must show only good packets.
 *
 * Every synchronize_threads() call pairs with the same-numbered phase in
 * test_server_passive_rst(), so statement order here is significant.
 */
static void test_client_passive_rst(unsigned int port)
{
	struct tcp_ao_counters ao1, ao2;
	struct tcp_ao_repair ao_img;
	struct tcp_sock_state img;
	sockaddr_af saddr;
	int sk, err;
	socklen_t slen = sizeof(err);

	sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
	if (sk < 0)
		test_error("socket()");

	if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
		test_error("setsockopt(TCP_AO_ADD_KEY)");

	synchronize_threads(); /* 1: MKT added => connect() */
	if (test_connect_socket(sk, this_ip_dest, port) <= 0)
		test_error("failed to connect()");

	synchronize_threads(); /* 2: accepted => send data */
	if (test_client_verify(sk, 100, quota / 100, TEST_TIMEOUT_SEC))
		test_fail("Failed to send data on connected socket");
	else
		test_ok("Verified established tcp connection");

	synchronize_threads(); /* 3: checkpoint/restore the connection */
	test_enable_repair(sk);
	test_sock_checkpoint(sk, &img, &saddr);
	test_ao_checkpoint(sk, &ao_img);
	test_kill_sk(sk);

	/* Jump ahead of what the server expects to receive next. */
	img.out.seq += quota;

	sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
	if (sk < 0)
		test_error("socket()");

	test_enable_repair(sk);
	test_sock_restore(sk, &img, &saddr, this_ip_dest, port);
	if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0, this_ip_dest, -1, 100, 100))
		test_error("setsockopt(TCP_AO_ADD_KEY)");
	test_ao_restore(sk, &ao_img);

	if (test_get_tcp_ao_counters(sk, &ao1))
		test_error("test_get_tcp_ao_counters()");

	test_disable_repair(sk);
	test_sock_state_free(&img);

	synchronize_threads(); /* 4: terminate server + send more on client */
	/* Inverted check: failing to exchange data is the expected outcome. */
	if (test_client_verify(sk, 100, quota / 100, 2 * TEST_TIMEOUT_SEC))
		test_ok("client connection broken post-seq-adjust");
	else
		test_fail("client connection still works post-seq-adjust");

	/* Best-effort wait; SO_ERROR below decides the result. */
	test_wait_for_exception(sk, TEST_TIMEOUT_SEC);

	if (getsockopt(sk, SOL_SOCKET, SO_ERROR, &err, &slen))
		test_error("getsockopt()");
	if (err != ECONNRESET && err != EPIPE)
		test_fail("client connection was not reset: %d", err);
	else
		test_ok("client connection was reset");

	if (test_get_tcp_ao_counters(sk, &ao2))
		test_error("test_get_tcp_ao_counters()");

	synchronize_threads(); /* 5: verified => closed */
	close(sk);
	test_tcp_ao_counters_cmp("client passive RST", &ao1, &ao2, TEST_CNT_GOOD);
}
391 
392 static void *client_fn(void *arg)
393 {
394 	struct netstat *ns_before, *ns_after;
395 	unsigned int port = test_server_port;
396 
397 	ns_before = netstat_read();
398 
399 	test_client_active_rst(port++);
400 	test_client_passive_rst(port++);
401 
402 	ns_after = netstat_read();
403 	netstats_check(ns_before, ns_after, "client");
404 	netstat_free(ns_after);
405 	netstat_free(ns_before);
406 
407 	synchronize_threads(); /* exit */
408 	return NULL;
409 }
410 
/*
 * Run server_fn() and client_fn() as the two sides of the test.
 *
 * 15 is presumably the planned number of test results expected by
 * test_init(). On the passing path, the server thread reports 5 results
 * (test_server_active_rst(): 2, netstats_check(): 3) and the client thread
 * reports 10 (test_client_active_rst(): 4, test_client_passive_rst(): 3,
 * netstats_check(): 3). Keep this number in sync when adding or removing
 * checks.
 */
int main(int argc, char *argv[])
{
	test_init(15, server_fn, client_fn);
	return 0;
}
416