xref: /linux/tools/testing/selftests/net/tcp_ao/self-connect.c (revision ae22a94997b8a03dcb3c922857c203246711f9d4)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Author: Dmitry Safonov <dima@arista.com> */
3 #include <inttypes.h>
4 #include "aolib.h"
5 
6 static union tcp_addr local_addr;
7 
8 static void __setup_lo_intf(const char *lo_intf,
9 			    const char *addr_str, uint8_t prefix)
10 {
11 	if (inet_pton(TEST_FAMILY, addr_str, &local_addr) != 1)
12 		test_error("Can't convert local ip address");
13 
14 	if (ip_addr_add(lo_intf, TEST_FAMILY, local_addr, prefix))
15 		test_error("Failed to add %s ip address", lo_intf);
16 
17 	if (link_set_up(lo_intf))
18 		test_error("Failed to bring %s up", lo_intf);
19 }
20 
/*
 * setup_lo_intf() - configure @lo_intf with the loopback address of the
 * address family this test is built for (::1/128 for IPv6 builds,
 * 127.0.0.1/8 otherwise).
 */
static void setup_lo_intf(const char *lo_intf)
{
#ifdef IPV6_TEST
	const char *loopback = "::1";
	const uint8_t plen = 128;
#else
	const char *loopback = "127.0.0.1";
	const uint8_t plen = 8;
#endif

	__setup_lo_intf(lo_intf, loopback, plen);
}
29 
30 static void tcp_self_connect(const char *tst, unsigned int port,
31 			     bool different_keyids, bool check_restore)
32 {
33 	uint64_t before_challenge_ack, after_challenge_ack;
34 	uint64_t before_syn_challenge, after_syn_challenge;
35 	struct tcp_ao_counters before_ao, after_ao;
36 	uint64_t before_aogood, after_aogood;
37 	struct netstat *ns_before, *ns_after;
38 	const size_t nr_packets = 20;
39 	struct tcp_ao_repair ao_img;
40 	struct tcp_sock_state img;
41 	sockaddr_af addr;
42 	int sk;
43 
44 	tcp_addr_to_sockaddr_in(&addr, &local_addr, htons(port));
45 
46 	sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
47 	if (sk < 0)
48 		test_error("socket()");
49 
50 	if (different_keyids) {
51 		if (test_add_key(sk, DEFAULT_TEST_PASSWORD, local_addr, -1, 5, 7))
52 			test_error("setsockopt(TCP_AO_ADD_KEY)");
53 		if (test_add_key(sk, DEFAULT_TEST_PASSWORD, local_addr, -1, 7, 5))
54 			test_error("setsockopt(TCP_AO_ADD_KEY)");
55 	} else {
56 		if (test_add_key(sk, DEFAULT_TEST_PASSWORD, local_addr, -1, 100, 100))
57 			test_error("setsockopt(TCP_AO_ADD_KEY)");
58 	}
59 
60 	if (bind(sk, (struct sockaddr *)&addr, sizeof(addr)) < 0)
61 		test_error("bind()");
62 
63 	ns_before = netstat_read();
64 	before_aogood = netstat_get(ns_before, "TCPAOGood", NULL);
65 	before_challenge_ack = netstat_get(ns_before, "TCPChallengeACK", NULL);
66 	before_syn_challenge = netstat_get(ns_before, "TCPSYNChallenge", NULL);
67 	if (test_get_tcp_ao_counters(sk, &before_ao))
68 		test_error("test_get_tcp_ao_counters()");
69 
70 	if (__test_connect_socket(sk, "lo", (struct sockaddr *)&addr,
71 				  sizeof(addr), TEST_TIMEOUT_SEC) < 0) {
72 		ns_after = netstat_read();
73 		netstat_print_diff(ns_before, ns_after);
74 		test_error("failed to connect()");
75 	}
76 
77 	if (test_client_verify(sk, 100, nr_packets, TEST_TIMEOUT_SEC)) {
78 		test_fail("%s: tcp connection verify failed", tst);
79 		close(sk);
80 		return;
81 	}
82 
83 	ns_after = netstat_read();
84 	after_aogood = netstat_get(ns_after, "TCPAOGood", NULL);
85 	after_challenge_ack = netstat_get(ns_after, "TCPChallengeACK", NULL);
86 	after_syn_challenge = netstat_get(ns_after, "TCPSYNChallenge", NULL);
87 	if (test_get_tcp_ao_counters(sk, &after_ao))
88 		test_error("test_get_tcp_ao_counters()");
89 	if (!check_restore) {
90 		/* to debug: netstat_print_diff(ns_before, ns_after); */
91 		netstat_free(ns_before);
92 	}
93 	netstat_free(ns_after);
94 
95 	if (after_aogood <= before_aogood) {
96 		test_fail("%s: TCPAOGood counter mismatch: %zu <= %zu",
97 			  tst, after_aogood, before_aogood);
98 		close(sk);
99 		return;
100 	}
101 	if (after_challenge_ack <= before_challenge_ack ||
102 	    after_syn_challenge <= before_syn_challenge) {
103 		/*
104 		 * It's also meant to test simultaneous open, so check
105 		 * these counters as well.
106 		 */
107 		test_fail("%s: Didn't challenge SYN or ACK: %zu <= %zu OR %zu <= %zu",
108 			  tst, after_challenge_ack, before_challenge_ack,
109 			  after_syn_challenge, before_syn_challenge);
110 		close(sk);
111 		return;
112 	}
113 
114 	if (test_tcp_ao_counters_cmp(tst, &before_ao, &after_ao, TEST_CNT_GOOD)) {
115 		close(sk);
116 		return;
117 	}
118 
119 	if (!check_restore) {
120 		test_ok("%s: connect TCPAOGood %" PRIu64 " => %" PRIu64,
121 				tst, before_aogood, after_aogood);
122 		close(sk);
123 		return;
124 	}
125 
126 	test_enable_repair(sk);
127 	test_sock_checkpoint(sk, &img, &addr);
128 #ifdef IPV6_TEST
129 	addr.sin6_port = htons(port + 1);
130 #else
131 	addr.sin_port = htons(port + 1);
132 #endif
133 	test_ao_checkpoint(sk, &ao_img);
134 	test_kill_sk(sk);
135 
136 	sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
137 	if (sk < 0)
138 		test_error("socket()");
139 
140 	test_enable_repair(sk);
141 	__test_sock_restore(sk, "lo", &img, &addr, &addr, sizeof(addr));
142 	if (different_keyids) {
143 		if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0,
144 					  local_addr, -1, 7, 5))
145 			test_error("setsockopt(TCP_AO_ADD_KEY)");
146 		if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0,
147 					  local_addr, -1, 5, 7))
148 			test_error("setsockopt(TCP_AO_ADD_KEY)");
149 	} else {
150 		if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0,
151 					  local_addr, -1, 100, 100))
152 			test_error("setsockopt(TCP_AO_ADD_KEY)");
153 	}
154 	test_ao_restore(sk, &ao_img);
155 	test_disable_repair(sk);
156 	test_sock_state_free(&img);
157 	if (test_client_verify(sk, 100, nr_packets, TEST_TIMEOUT_SEC)) {
158 		test_fail("%s: tcp connection verify failed", tst);
159 		close(sk);
160 		return;
161 	}
162 	ns_after = netstat_read();
163 	after_aogood = netstat_get(ns_after, "TCPAOGood", NULL);
164 	/* to debug: netstat_print_diff(ns_before, ns_after); */
165 	netstat_free(ns_before);
166 	netstat_free(ns_after);
167 	close(sk);
168 	if (after_aogood <= before_aogood) {
169 		test_fail("%s: TCPAOGood counter mismatch: %zu <= %zu",
170 			  tst, after_aogood, before_aogood);
171 		return;
172 	}
173 	test_ok("%s: connect TCPAOGood %" PRIu64 " => %" PRIu64,
174 			tst, before_aogood, after_aogood);
175 }
176 
177 static void *client_fn(void *arg)
178 {
179 	unsigned int port = test_server_port;
180 
181 	setup_lo_intf("lo");
182 
183 	tcp_self_connect("self-connect(same keyids)", port++, false, false);
184 	tcp_self_connect("self-connect(different keyids)", port++, true, false);
185 	tcp_self_connect("self-connect(restore)", port, false, true);
186 	port += 2;
187 	tcp_self_connect("self-connect(restore, different keyids)", port, true, true);
188 	port += 2;
189 
190 	return NULL;
191 }
192 
193 int main(int argc, char *argv[])
194 {
195 	test_init(4, client_fn, NULL);
196 	return 0;
197 }
198