xref: /linux/tools/testing/selftests/net/tcp_ao/icmps-discard.c (revision 186779c036468038b0d077ec5333a51512f867e5)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
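 * Selftest that verifies that incoming ICMPs are ignored:
 * the TCP connection stays alive, no hard or soft errors get reported
 * to userspace, and the counter for ignored ICMPs is updated.
 * (When built with TEST_ICMPS_ACCEPT, the expectations are inverted:
 * ICMPs are accepted and the server is expected to fail with a hard error.)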
6  *
7  * RFC5925, 7.8:
8  * >> A TCP-AO implementation MUST default to ignore incoming ICMPv4
9  * messages of Type 3 (destination unreachable), Codes 2-4 (protocol
 * unreachable, port unreachable, and fragmentation needed -- 'hard
 * errors'), and ICMPv6 Type 1 (destination unreachable), Code 1
12  * (administratively prohibited) and Code 4 (port unreachable) intended
13  * for connections in synchronized states (ESTABLISHED, FIN-WAIT-1, FIN-
14  * WAIT-2, CLOSE-WAIT, CLOSING, LAST-ACK, TIME-WAIT) that match MKTs.
15  *
16  * Author: Dmitry Safonov <dima@arista.com>
17  */
#include <inttypes.h>
#include <linux/icmp.h>
#include <linux/icmpv6.h>
#include <linux/ipv6.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <string.h>
#include <sys/socket.h>
#include "aolib.h"
#include "../../../../include/linux/compiler.h"
27 
/* Traffic shape: the client verifies packets_nr packets of packet_size bytes per round. */
const size_t packets_nr = 20;
const size_t packet_size = 100;
/* Name of the netstat counter for ICMPs dropped by TCP-AO. */
const char *tcpao_icmps	= "TCPAODroppedIcmps";

/*
 * Address-family specific settings: the netstat destination-unreachable
 * counter and the socket option level/name used to enable RECVERR.
 */
#ifndef IPV6_TEST
const char *dst_unreach	= "InDestUnreachs";
const int sk_ip_level	= SOL_IP;
const int sk_recverr	= IP_RECVERR;
#else
const char *dst_unreach	= "Icmp6InDestUnreachs";
const int sk_ip_level	= SOL_IPV6;
const int sk_recverr	= IPV6_RECVERR;
#endif

/*
 * With TEST_ICMPS_ACCEPT the listener accepts ICMPs, so the server is
 * expected to fail with a hard error: the pass/fail reporters swap roles.
 */
#ifndef TEST_ICMPS_ACCEPT
# define test_icmps_fail	test_fail
# define test_icmps_ok		test_ok
#else
# define test_icmps_fail	test_ok
# define test_icmps_ok		test_fail
#endif
50 
51 static void serve_interfered(int sk)
52 {
53 	ssize_t test_quota = packet_size * packets_nr * 10;
54 	uint64_t dest_unreach_a, dest_unreach_b;
55 	uint64_t icmp_ignored_a, icmp_ignored_b;
56 	struct tcp_counters cnt1, cnt2;
57 	bool counter_not_found;
58 	struct netstat *ns_after, *ns_before;
59 	ssize_t bytes;
60 
61 	ns_before = netstat_read();
62 	dest_unreach_a = netstat_get(ns_before, dst_unreach, NULL);
63 	icmp_ignored_a = netstat_get(ns_before, tcpao_icmps, NULL);
64 	if (test_get_tcp_counters(sk, &cnt1))
65 		test_error("test_get_tcp_counters()");
66 	bytes = test_server_run(sk, test_quota, 0);
67 	ns_after = netstat_read();
68 	netstat_print_diff(ns_before, ns_after);
69 	dest_unreach_b = netstat_get(ns_after, dst_unreach, NULL);
70 	icmp_ignored_b = netstat_get(ns_after, tcpao_icmps,
71 					&counter_not_found);
72 	if (test_get_tcp_counters(sk, &cnt2))
73 		test_error("test_get_tcp_counters()");
74 
75 	netstat_free(ns_before);
76 	netstat_free(ns_after);
77 
78 	if (dest_unreach_a >= dest_unreach_b) {
79 		test_fail("%s counter didn't change: %" PRIu64 " >= %" PRIu64,
80 				dst_unreach, dest_unreach_a, dest_unreach_b);
81 		return;
82 	}
83 	test_ok("%s delivered %" PRIu64,
84 		dst_unreach, dest_unreach_b - dest_unreach_a);
85 	if (bytes < 0)
86 		test_icmps_fail("Server failed with %zd: %s", bytes, strerrordesc_np(-bytes));
87 	else
88 		test_icmps_ok("Server survived %zd bytes of traffic", test_quota);
89 	if (counter_not_found) {
90 		test_fail("Not found %s counter", tcpao_icmps);
91 		return;
92 	}
93 #ifdef TEST_ICMPS_ACCEPT
94 	test_assert_counters(NULL, &cnt1, &cnt2, TEST_CNT_GOOD);
95 #else
96 	test_assert_counters(NULL, &cnt1, &cnt2, TEST_CNT_GOOD | TEST_CNT_AO_DROPPED_ICMP);
97 #endif
98 	if (icmp_ignored_a >= icmp_ignored_b) {
99 		test_icmps_fail("%s counter didn't change: %" PRIu64 " >= %" PRIu64,
100 				tcpao_icmps, icmp_ignored_a, icmp_ignored_b);
101 		return;
102 	}
103 	test_icmps_ok("ICMPs ignored %" PRIu64, icmp_ignored_b - icmp_ignored_a);
104 }
105 
106 static void *server_fn(void *arg)
107 {
108 	int val, sk, lsk;
109 	bool accept_icmps = false;
110 
111 	lsk = test_listen_socket(this_ip_addr, test_server_port, 1);
112 
113 #ifdef TEST_ICMPS_ACCEPT
114 	accept_icmps = true;
115 #endif
116 
117 	if (test_set_ao_flags(lsk, false, accept_icmps))
118 		test_error("setsockopt(TCP_AO_INFO)");
119 
120 	if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
121 		test_error("setsockopt(TCP_AO_ADD_KEY)");
122 	synchronize_threads();
123 
124 	if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
125 		test_error("test_wait_fd()");
126 
127 	sk = accept(lsk, NULL, NULL);
128 	if (sk < 0)
129 		test_error("accept()");
130 
131 	/* Fail on hard ip errors, such as dest unreachable (RFC1122) */
132 	val = 1;
133 	if (setsockopt(sk, sk_ip_level, sk_recverr, &val, sizeof(val)))
134 		test_error("setsockopt()");
135 
136 	synchronize_threads();
137 
138 	serve_interfered(sk);
139 	return NULL;
140 }
141 
/* Traffic bookkeeping updated by the client thread (currently only written, never read). */
static size_t packets_sent, icmps_sent;
144 
/*
 * Accumulate the 16-bit one's complement sum (RFC 1071) of @len bytes at
 * @data onto @sum, without folding carries (checksum4_fold() does that).
 *
 * Words are loaded in host byte order through memcpy(), so @data may be any
 * object (e.g. a packet header struct) without alignment or strict-aliasing
 * concerns.  A trailing odd byte is zero-padded into a full 16-bit word:
 * it is neither sign-extended (as a plain 'char' read would be for bytes
 * >= 0x80) nor put into the wrong half of the word on big-endian hosts.
 *
 * The 32-bit accumulator is not carry-folded here, so this is meant for
 * small buffers such as packet headers.
 */
static uint32_t checksum4_nofold(const void *data, size_t len, uint32_t sum)
{
	const uint8_t *bytes = data;
	uint16_t word;
	size_t i;

	for (i = 0; i + 1 < len; i += 2) {
		memcpy(&word, bytes + i, sizeof(word));
		sum += word;
	}
	if (len & 1) {
		const uint8_t tail[2] = { bytes[len - 1], 0 };

		memcpy(&word, tail, sizeof(word));
		sum += word;
	}
	return sum;
}
156 
/*
 * Return the 16-bit one's complement checksum of @len bytes at @data,
 * continuing from the partial sum @sum: carries above bit 15 are folded
 * back in until the value fits in 16 bits, then the result is complemented.
 */
static uint16_t checksum4_fold(void *data, size_t len, uint32_t sum)
{
	uint32_t acc = checksum4_nofold(data, len, sum);

	while (acc >> 16)
		acc = (acc >> 16) + (acc & 0xFFFF);

	return (uint16_t)~acc;
}
164 
/*
 * Fill in the option-less IPv4 header @iph of a @packet_len-byte datagram
 * carrying protocol @proto from @src to @dst (TTL 2), including the header
 * checksum.  Fields not set here (id, frag_off) keep whatever the caller
 * initialised them to.
 */
static void set_ip4hdr(struct iphdr *iph, size_t packet_len, int proto,
		struct sockaddr_in *src, struct sockaddr_in *dst)
{
	iph->version	= 4;
	iph->ihl	= 5;	/* header length in 32-bit words, i.e. 20 bytes */
	iph->tos	= 0;
	iph->tot_len	= htons(packet_len);
	iph->ttl	= 2;
	iph->protocol	= proto;
	iph->saddr	= src->sin_addr.s_addr;
	iph->daddr	= dst->sin_addr.s_addr;
	/*
	 * RFC 791: the checksum covers the whole header (ihl 32-bit words,
	 * i.e. ihl * 4 bytes, so saddr/daddr included) and is computed with
	 * the checksum field itself set to zero.
	 */
	iph->check	= 0;
	iph->check	= checksum4_fold((void *)iph, iph->ihl * 4, 0);
}
178 
179 static void icmp_interfere4(uint8_t type, uint8_t code, uint32_t rcv_nxt,
180 		struct sockaddr_in *src, struct sockaddr_in *dst)
181 {
182 	int sk = socket(AF_INET, SOCK_RAW, IPPROTO_RAW);
183 	struct {
184 		struct iphdr iph;
185 		struct icmphdr icmph;
186 		struct iphdr iphe;
187 		struct {
188 			uint16_t sport;
189 			uint16_t dport;
190 			uint32_t seq;
191 		} tcph;
192 	} packet = {};
193 	size_t packet_len;
194 	ssize_t bytes;
195 
196 	if (sk < 0)
197 		test_error("socket(AF_INET, SOCK_RAW, IPPROTO_RAW)");
198 
199 	packet_len = sizeof(packet);
200 	set_ip4hdr(&packet.iph, packet_len, IPPROTO_ICMP, src, dst);
201 
202 	packet.icmph.type = type;
203 	packet.icmph.code = code;
204 	if (code == ICMP_FRAG_NEEDED) {
205 		randomize_buffer(&packet.icmph.un.frag.mtu,
206 				sizeof(packet.icmph.un.frag.mtu));
207 	}
208 
209 	packet_len = sizeof(packet.iphe) + sizeof(packet.tcph);
210 	set_ip4hdr(&packet.iphe, packet_len, IPPROTO_TCP, dst, src);
211 
212 	packet.tcph.sport = dst->sin_port;
213 	packet.tcph.dport = src->sin_port;
214 	packet.tcph.seq = htonl(rcv_nxt);
215 
216 	packet_len = sizeof(packet) - sizeof(packet.iph);
217 	packet.icmph.checksum = checksum4_fold((void *)&packet.icmph,
218 						packet_len, 0);
219 
220 	bytes = sendto(sk, &packet, sizeof(packet), 0,
221 		       (struct sockaddr *)dst, sizeof(*dst));
222 	if (bytes != sizeof(packet))
223 		test_error("send(): %zd", bytes);
224 	icmps_sent++;
225 
226 	close(sk);
227 }
228 
/*
 * Fill in the fixed IPv6 header @iph: @packet_len bytes of payload with
 * next header @proto, from @src to @dst, hop limit 2.  Traffic class and
 * flow label keep whatever the caller initialised them to.
 */
static void set_ip6hdr(struct ipv6hdr *iph, size_t packet_len, int proto,
		struct sockaddr_in6 *src, struct sockaddr_in6 *dst)
{
	iph->saddr		= src->sin6_addr;
	iph->daddr		= dst->sin6_addr;
	iph->hop_limit		= 2;
	iph->nexthdr		= proto;
	iph->payload_len	= htons(packet_len);
	iph->version		= 6;
}
239 
/*
 * Fold a 32-bit one's complement accumulator @csum into 16 bits and return
 * its complement.  Two end-around-carry rounds always suffice for 32 bits.
 */
static inline uint16_t csum_fold(uint32_t csum)
{
	csum = (csum >> 16) + (csum & 0xffff);
	csum = (csum >> 16) + (csum & 0xffff);

	return (uint16_t)~csum;
}
248 
/* One's complement addition: add @addend to @csum with end-around carry. */
static inline uint32_t csum_add(uint32_t csum, uint32_t addend)
{
	uint32_t total = csum + addend;

	/* An unsigned wrap-around is a carry out of bit 31: feed it back in. */
	if (total < addend)
		total++;

	return total;
}
256 
257 noinline uint32_t checksum6_nofold(void *data, size_t len, uint32_t sum)
258 {
259 	uint16_t *words = data;
260 	size_t i;
261 
262 	for (i = 0; i < len / sizeof(uint16_t); i++)
263 		sum = csum_add(sum, words[i]);
264 	if (len & 1)
265 		sum = csum_add(sum, ((char *)data)[len - 1]);
266 	return sum;
267 }
268 
269 noinline uint16_t icmp6_checksum(struct sockaddr_in6 *src,
270 				 struct sockaddr_in6 *dst,
271 				 void *ptr, size_t len, uint8_t proto)
272 {
273 	struct {
274 		struct in6_addr saddr;
275 		struct in6_addr daddr;
276 		uint32_t payload_len;
277 		uint8_t zero[3];
278 		uint8_t nexthdr;
279 	} pseudo_header = {};
280 	uint32_t sum;
281 
282 	pseudo_header.saddr		= src->sin6_addr;
283 	pseudo_header.daddr		= dst->sin6_addr;
284 	pseudo_header.payload_len	= htonl(len);
285 	pseudo_header.nexthdr		= proto;
286 
287 	sum = checksum6_nofold(&pseudo_header, sizeof(pseudo_header), 0);
288 	sum = checksum6_nofold(ptr, len, sum);
289 
290 	return csum_fold(sum);
291 }
292 
293 static void icmp6_interfere(int type, int code, uint32_t rcv_nxt,
294 		struct sockaddr_in6 *src, struct sockaddr_in6 *dst)
295 {
296 	int sk = socket(AF_INET6, SOCK_RAW, IPPROTO_RAW);
297 	struct sockaddr_in6 dst_raw = *dst;
298 	struct {
299 		struct ipv6hdr iph;
300 		struct icmp6hdr icmph;
301 		struct ipv6hdr iphe;
302 		struct {
303 			uint16_t sport;
304 			uint16_t dport;
305 			uint32_t seq;
306 		} tcph;
307 	} packet = {};
308 	size_t packet_len;
309 	ssize_t bytes;
310 
311 
312 	if (sk < 0)
313 		test_error("socket(AF_INET6, SOCK_RAW, IPPROTO_RAW)");
314 
315 	packet_len = sizeof(packet) - sizeof(packet.iph);
316 	set_ip6hdr(&packet.iph, packet_len, IPPROTO_ICMPV6, src, dst);
317 
318 	packet.icmph.icmp6_type = type;
319 	packet.icmph.icmp6_code = code;
320 
321 	packet_len = sizeof(packet.iphe) + sizeof(packet.tcph);
322 	set_ip6hdr(&packet.iphe, packet_len, IPPROTO_TCP, dst, src);
323 
324 	packet.tcph.sport = dst->sin6_port;
325 	packet.tcph.dport = src->sin6_port;
326 	packet.tcph.seq = htonl(rcv_nxt);
327 
328 	packet_len = sizeof(packet) - sizeof(packet.iph);
329 
330 	packet.icmph.icmp6_cksum = icmp6_checksum(src, dst,
331 			(void *)&packet.icmph, packet_len, IPPROTO_ICMPV6);
332 
333 	dst_raw.sin6_port = htons(IPPROTO_RAW);
334 	bytes = sendto(sk, &packet, sizeof(packet), 0,
335 		       (struct sockaddr *)&dst_raw, sizeof(dst_raw));
336 	if (bytes != sizeof(packet))
337 		test_error("send(): %zd", bytes);
338 	icmps_sent++;
339 
340 	close(sk);
341 }
342 
343 static uint32_t get_rcv_nxt(int sk)
344 {
345 	int val = TCP_REPAIR_ON;
346 	uint32_t ret;
347 	socklen_t sz = sizeof(ret);
348 
349 	if (setsockopt(sk, SOL_TCP, TCP_REPAIR, &val, sizeof(val)))
350 		test_error("setsockopt(TCP_REPAIR)");
351 	val = TCP_RECV_QUEUE;
352 	if (setsockopt(sk, SOL_TCP, TCP_REPAIR_QUEUE, &val, sizeof(val)))
353 		test_error("setsockopt(TCP_REPAIR_QUEUE)");
354 	if (getsockopt(sk, SOL_TCP, TCP_QUEUE_SEQ, &ret, &sz))
355 		test_error("getsockopt(TCP_QUEUE_SEQ)");
356 	val = TCP_REPAIR_OFF_NO_WP;
357 	if (setsockopt(sk, SOL_TCP, TCP_REPAIR, &val, sizeof(val)))
358 		test_error("setsockopt(TCP_REPAIR)");
359 	return ret;
360 }
361 
/*
 * Inject @nr rounds of the "hard error" ICMPs that RFC5925 7.8 says TCP-AO
 * must ignore by default, each quoting sequence number @rcv_nxt.  @src and
 * @dst are the connection's local and peer addresses, both either
 * struct sockaddr_in or struct sockaddr_in6.
 *
 * icmps_sent is accounted by icmp_interfere4()/icmp6_interfere() for every
 * packet they send; it must not be bumped again here (that used to count
 * every ICMP twice).
 */
static void icmp_interfere(const size_t nr, uint32_t rcv_nxt, void *src, void *dst)
{
	struct sockaddr_in *saddr4 = src;
	struct sockaddr_in *daddr4 = dst;
	struct sockaddr_in6 *saddr6 = src;
	struct sockaddr_in6 *daddr6 = dst;
	size_t i;

	/* The family is read through the IPv4 view for both address kinds. */
	if (saddr4->sin_family != daddr4->sin_family)
		test_error("Different address families");

	for (i = 0; i < nr; i++) {
		if (saddr4->sin_family == AF_INET) {
			/* ICMPv4 Type 3, Codes 2-4 */
			icmp_interfere4(ICMP_DEST_UNREACH, ICMP_PROT_UNREACH,
					rcv_nxt, saddr4, daddr4);
			icmp_interfere4(ICMP_DEST_UNREACH, ICMP_PORT_UNREACH,
					rcv_nxt, saddr4, daddr4);
			icmp_interfere4(ICMP_DEST_UNREACH, ICMP_FRAG_NEEDED,
					rcv_nxt, saddr4, daddr4);
		} else if (saddr4->sin_family == AF_INET6) {
			/* ICMPv6 Type 1, Codes 1 and 4 */
			icmp6_interfere(ICMPV6_DEST_UNREACH,
					ICMPV6_ADM_PROHIBITED,
					rcv_nxt, saddr6, daddr6);
			icmp6_interfere(ICMPV6_DEST_UNREACH,
					ICMPV6_PORT_UNREACH,
					rcv_nxt, saddr6, daddr6);
		} else {
			test_error("Not ip address family");
		}
	}
}
395 
396 static void send_interfered(int sk)
397 {
398 	struct sockaddr_in6 src, dst;
399 	socklen_t addr_sz;
400 
401 	addr_sz = sizeof(src);
402 	if (getsockname(sk, &src, &addr_sz))
403 		test_error("getsockname()");
404 	addr_sz = sizeof(dst);
405 	if (getpeername(sk, &dst, &addr_sz))
406 		test_error("getpeername()");
407 
408 	while (1) {
409 		uint32_t rcv_nxt;
410 
411 		if (test_client_verify(sk, packet_size, packets_nr)) {
412 			test_fail("client: connection is broken");
413 			return;
414 		}
415 		packets_sent += packets_nr;
416 		rcv_nxt = get_rcv_nxt(sk);
417 		icmp_interfere(packets_nr, rcv_nxt, (void *)&src, (void *)&dst);
418 	}
419 }
420 
421 static void *client_fn(void *arg)
422 {
423 	int sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
424 
425 	if (sk < 0)
426 		test_error("socket()");
427 
428 	if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
429 		test_error("setsockopt(TCP_AO_ADD_KEY)");
430 
431 	synchronize_threads();
432 	if (test_connect_socket(sk, this_ip_dest, test_server_port) <= 0)
433 		test_error("failed to connect()");
434 	synchronize_threads();
435 
436 	send_interfered(sk);
437 
438 	/* Not expecting client to quit */
439 	test_fail("client disconnected");
440 
441 	return NULL;
442 }
443 
444 int main(int argc, char *argv[])
445 {
446 	test_init(4, server_fn, client_fn);
447 	return 0;
448 }
449