xref: /linux/tools/testing/selftests/bpf/progs/test_tcp_custom_syncookie.c (revision 6500780cffa7f221431fa4a2ec1c2f6bc51dcb6b)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright Amazon.com Inc. or its affiliates. */
3 
4 #include "vmlinux.h"
5 
6 #include <bpf/bpf_helpers.h>
7 #include <bpf/bpf_endian.h>
8 #include "bpf_tracing_net.h"
9 #include "bpf_kfuncs.h"
10 #include "test_siphash.h"
11 #include "test_tcp_custom_syncookie.h"
12 
13 /* Hash is calculated for each client and split into ISN and TS.
14  *
15  *       MSB                                   LSB
16  * ISN:  | 31 ... 8 | 7 6 |   5 |    4 | 3 2 1 0 |
17  *       |   Hash_1 | MSS | ECN | SACK |  WScale |
18  *
19  * TS:   | 31 ... 8 |          7 ... 0           |
20  *       |   Random |           Hash_2           |
21  */
22 #define COOKIE_BITS	8
23 #define COOKIE_MASK	(((__u32)1 << COOKIE_BITS) - 1)
24 
25 enum {
26 	/* 0xf is invalid thus means that SYN did not have WScale. */
27 	BPF_SYNCOOKIE_WSCALE_MASK	= (1 << 4) - 1,
28 	BPF_SYNCOOKIE_SACK		= (1 << 4),
29 	BPF_SYNCOOKIE_ECN		= (1 << 5),
30 };
31 
32 #define MSS_LOCAL_IPV4	65495
33 #define MSS_LOCAL_IPV6	65476
34 
35 const __u16 msstab4[] = {
36 	536,
37 	1300,
38 	1460,
39 	MSS_LOCAL_IPV4,
40 };
41 
42 const __u16 msstab6[] = {
43 	1280 - 60, /* IPV6_MIN_MTU - 60 */
44 	1480 - 60,
45 	9000 - 60,
46 	MSS_LOCAL_IPV6,
47 };
48 
49 static siphash_key_t test_key_siphash = {
50 	{ 0x0706050403020100ULL, 0x0f0e0d0c0b0a0908ULL }
51 };
52 
53 struct tcp_syncookie {
54 	struct __sk_buff *skb;
55 	void *data_end;
56 	struct ethhdr *eth;
57 	struct iphdr *ipv4;
58 	struct ipv6hdr *ipv6;
59 	struct tcphdr *tcp;
60 	union {
61 		char *ptr;
62 		__be32 *ptr32;
63 	};
64 	struct bpf_tcp_req_attrs attrs;
65 	u32 cookie;
66 	u64 first;
67 };
68 
69 bool handled_syn, handled_ack;
70 
71 static int tcp_load_headers(struct tcp_syncookie *ctx)
72 {
73 	ctx->data_end = (void *)(long)ctx->skb->data_end;
74 	ctx->eth = (struct ethhdr *)(long)ctx->skb->data;
75 
76 	if (ctx->eth + 1 > ctx->data_end)
77 		goto err;
78 
79 	switch (bpf_ntohs(ctx->eth->h_proto)) {
80 	case ETH_P_IP:
81 		ctx->ipv4 = (struct iphdr *)(ctx->eth + 1);
82 
83 		if (ctx->ipv4 + 1 > ctx->data_end)
84 			goto err;
85 
86 		if (ctx->ipv4->ihl != sizeof(*ctx->ipv4) / 4)
87 			goto err;
88 
89 		if (ctx->ipv4->version != 4)
90 			goto err;
91 
92 		if (ctx->ipv4->protocol != IPPROTO_TCP)
93 			goto err;
94 
95 		ctx->tcp = (struct tcphdr *)(ctx->ipv4 + 1);
96 		break;
97 	case ETH_P_IPV6:
98 		ctx->ipv6 = (struct ipv6hdr *)(ctx->eth + 1);
99 
100 		if (ctx->ipv6 + 1 > ctx->data_end)
101 			goto err;
102 
103 		if (ctx->ipv6->version != 6)
104 			goto err;
105 
106 		if (ctx->ipv6->nexthdr != NEXTHDR_TCP)
107 			goto err;
108 
109 		ctx->tcp = (struct tcphdr *)(ctx->ipv6 + 1);
110 		break;
111 	default:
112 		goto err;
113 	}
114 
115 	if (ctx->tcp + 1 > ctx->data_end)
116 		goto err;
117 
118 	return 0;
119 err:
120 	return -1;
121 }
122 
123 static int tcp_reload_headers(struct tcp_syncookie *ctx)
124 {
125 	/* Without volatile,
126 	 * R3 32-bit pointer arithmetic prohibited
127 	 */
128 	volatile u64 data_len = ctx->skb->data_end - ctx->skb->data;
129 
130 	if (ctx->tcp->doff < sizeof(*ctx->tcp) / 4)
131 		goto err;
132 
133 	/* Needed to calculate csum and parse TCP options. */
134 	if (bpf_skb_change_tail(ctx->skb, data_len + 60 - ctx->tcp->doff * 4, 0))
135 		goto err;
136 
137 	ctx->data_end = (void *)(long)ctx->skb->data_end;
138 	ctx->eth = (struct ethhdr *)(long)ctx->skb->data;
139 	if (ctx->ipv4) {
140 		ctx->ipv4 = (struct iphdr *)(ctx->eth + 1);
141 		ctx->ipv6 = NULL;
142 		ctx->tcp = (struct tcphdr *)(ctx->ipv4 + 1);
143 	} else {
144 		ctx->ipv4 = NULL;
145 		ctx->ipv6 = (struct ipv6hdr *)(ctx->eth + 1);
146 		ctx->tcp = (struct tcphdr *)(ctx->ipv6 + 1);
147 	}
148 
149 	if ((void *)ctx->tcp + 60 > ctx->data_end)
150 		goto err;
151 
152 	return 0;
153 err:
154 	return -1;
155 }
156 
157 static __sum16 tcp_v4_csum(struct tcp_syncookie *ctx, __wsum csum)
158 {
159 	return csum_tcpudp_magic(ctx->ipv4->saddr, ctx->ipv4->daddr,
160 				 ctx->tcp->doff * 4, IPPROTO_TCP, csum);
161 }
162 
163 static __sum16 tcp_v6_csum(struct tcp_syncookie *ctx, __wsum csum)
164 {
165 	return csum_ipv6_magic(&ctx->ipv6->saddr, &ctx->ipv6->daddr,
166 			       ctx->tcp->doff * 4, IPPROTO_TCP, csum);
167 }
168 
169 static int tcp_validate_header(struct tcp_syncookie *ctx)
170 {
171 	s64 csum;
172 
173 	if (tcp_reload_headers(ctx))
174 		goto err;
175 
176 	csum = bpf_csum_diff(0, 0, (void *)ctx->tcp, ctx->tcp->doff * 4, 0);
177 	if (csum < 0)
178 		goto err;
179 
180 	if (ctx->ipv4) {
181 		/* check tcp_v4_csum(csum) is 0 if not on lo. */
182 
183 		csum = bpf_csum_diff(0, 0, (void *)ctx->ipv4, ctx->ipv4->ihl * 4, 0);
184 		if (csum < 0)
185 			goto err;
186 
187 		if (csum_fold(csum) != 0)
188 			goto err;
189 	} else if (ctx->ipv6) {
190 		/* check tcp_v6_csum(csum) is 0 if not on lo. */
191 	}
192 
193 	return 0;
194 err:
195 	return -1;
196 }
197 
198 static int tcp_parse_option(__u32 index, struct tcp_syncookie *ctx)
199 {
200 	char opcode, opsize;
201 
202 	if (ctx->ptr + 1 > ctx->data_end)
203 		goto stop;
204 
205 	opcode = *ctx->ptr++;
206 
207 	if (opcode == TCPOPT_EOL)
208 		goto stop;
209 
210 	if (opcode == TCPOPT_NOP)
211 		goto next;
212 
213 	if (ctx->ptr + 1 > ctx->data_end)
214 		goto stop;
215 
216 	opsize = *ctx->ptr++;
217 
218 	if (opsize < 2)
219 		goto stop;
220 
221 	switch (opcode) {
222 	case TCPOPT_MSS:
223 		if (opsize == TCPOLEN_MSS && ctx->tcp->syn &&
224 		    ctx->ptr + (TCPOLEN_MSS - 2) < ctx->data_end)
225 			ctx->attrs.mss = get_unaligned_be16(ctx->ptr);
226 		break;
227 	case TCPOPT_WINDOW:
228 		if (opsize == TCPOLEN_WINDOW && ctx->tcp->syn &&
229 		    ctx->ptr + (TCPOLEN_WINDOW - 2) < ctx->data_end) {
230 			ctx->attrs.wscale_ok = 1;
231 			ctx->attrs.snd_wscale = *ctx->ptr;
232 		}
233 		break;
234 	case TCPOPT_TIMESTAMP:
235 		if (opsize == TCPOLEN_TIMESTAMP &&
236 		    ctx->ptr + (TCPOLEN_TIMESTAMP - 2) < ctx->data_end) {
237 			ctx->attrs.rcv_tsval = get_unaligned_be32(ctx->ptr);
238 			ctx->attrs.rcv_tsecr = get_unaligned_be32(ctx->ptr + 4);
239 
240 			if (ctx->tcp->syn && ctx->attrs.rcv_tsecr)
241 				ctx->attrs.tstamp_ok = 0;
242 			else
243 				ctx->attrs.tstamp_ok = 1;
244 		}
245 		break;
246 	case TCPOPT_SACK_PERM:
247 		if (opsize == TCPOLEN_SACK_PERM && ctx->tcp->syn &&
248 		    ctx->ptr + (TCPOLEN_SACK_PERM - 2) < ctx->data_end)
249 			ctx->attrs.sack_ok = 1;
250 		break;
251 	}
252 
253 	ctx->ptr += opsize - 2;
254 next:
255 	return 0;
256 stop:
257 	return 1;
258 }
259 
260 static void tcp_parse_options(struct tcp_syncookie *ctx)
261 {
262 	ctx->ptr = (char *)(ctx->tcp + 1);
263 
264 	bpf_loop(40, tcp_parse_option, ctx, 0);
265 }
266 
267 static int tcp_validate_sysctl(struct tcp_syncookie *ctx)
268 {
269 	if ((ctx->ipv4 && ctx->attrs.mss != MSS_LOCAL_IPV4) ||
270 	    (ctx->ipv6 && ctx->attrs.mss != MSS_LOCAL_IPV6))
271 		goto err;
272 
273 	if (!ctx->attrs.wscale_ok || ctx->attrs.snd_wscale != 7)
274 		goto err;
275 
276 	if (!ctx->attrs.tstamp_ok)
277 		goto err;
278 
279 	if (!ctx->attrs.sack_ok)
280 		goto err;
281 
282 	if (!ctx->tcp->ece || !ctx->tcp->cwr)
283 		goto err;
284 
285 	return 0;
286 err:
287 	return -1;
288 }
289 
290 static void tcp_prepare_cookie(struct tcp_syncookie *ctx)
291 {
292 	u32 seq = bpf_ntohl(ctx->tcp->seq);
293 	u64 first = 0, second;
294 	int mssind = 0;
295 	u32 hash;
296 
297 	if (ctx->ipv4) {
298 		for (mssind = ARRAY_SIZE(msstab4) - 1; mssind; mssind--)
299 			if (ctx->attrs.mss >= msstab4[mssind])
300 				break;
301 
302 		ctx->attrs.mss = msstab4[mssind];
303 
304 		first = (u64)ctx->ipv4->saddr << 32 | ctx->ipv4->daddr;
305 	} else if (ctx->ipv6) {
306 		for (mssind = ARRAY_SIZE(msstab6) - 1; mssind; mssind--)
307 			if (ctx->attrs.mss >= msstab6[mssind])
308 				break;
309 
310 		ctx->attrs.mss = msstab6[mssind];
311 
312 		first = (u64)ctx->ipv6->saddr.in6_u.u6_addr8[0] << 32 |
313 			ctx->ipv6->daddr.in6_u.u6_addr32[0];
314 	}
315 
316 	second = (u64)seq << 32 | ctx->tcp->source << 16 | ctx->tcp->dest;
317 	hash = siphash_2u64(first, second, &test_key_siphash);
318 
319 	if (ctx->attrs.tstamp_ok) {
320 		ctx->attrs.rcv_tsecr = bpf_get_prandom_u32();
321 		ctx->attrs.rcv_tsecr &= ~COOKIE_MASK;
322 		ctx->attrs.rcv_tsecr |= hash & COOKIE_MASK;
323 	}
324 
325 	hash &= ~COOKIE_MASK;
326 	hash |= mssind << 6;
327 
328 	if (ctx->attrs.wscale_ok)
329 		hash |= ctx->attrs.snd_wscale & BPF_SYNCOOKIE_WSCALE_MASK;
330 
331 	if (ctx->attrs.sack_ok)
332 		hash |= BPF_SYNCOOKIE_SACK;
333 
334 	if (ctx->attrs.tstamp_ok && ctx->tcp->ece && ctx->tcp->cwr)
335 		hash |= BPF_SYNCOOKIE_ECN;
336 
337 	ctx->cookie = hash;
338 }
339 
340 static void tcp_write_options(struct tcp_syncookie *ctx)
341 {
342 	ctx->ptr32 = (__be32 *)(ctx->tcp + 1);
343 
344 	*ctx->ptr32++ = bpf_htonl(TCPOPT_MSS << 24 | TCPOLEN_MSS << 16 |
345 				  ctx->attrs.mss);
346 
347 	if (ctx->attrs.wscale_ok)
348 		*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
349 					  TCPOPT_WINDOW << 16 |
350 					  TCPOLEN_WINDOW << 8 |
351 					  ctx->attrs.snd_wscale);
352 
353 	if (ctx->attrs.tstamp_ok) {
354 		if (ctx->attrs.sack_ok)
355 			*ctx->ptr32++ = bpf_htonl(TCPOPT_SACK_PERM << 24 |
356 						  TCPOLEN_SACK_PERM << 16 |
357 						  TCPOPT_TIMESTAMP << 8 |
358 						  TCPOLEN_TIMESTAMP);
359 		else
360 			*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
361 						  TCPOPT_NOP << 16 |
362 						  TCPOPT_TIMESTAMP << 8 |
363 						  TCPOLEN_TIMESTAMP);
364 
365 		*ctx->ptr32++ = bpf_htonl(ctx->attrs.rcv_tsecr);
366 		*ctx->ptr32++ = bpf_htonl(ctx->attrs.rcv_tsval);
367 	} else if (ctx->attrs.sack_ok) {
368 		*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
369 					  TCPOPT_NOP << 16 |
370 					  TCPOPT_SACK_PERM << 8 |
371 					  TCPOLEN_SACK_PERM);
372 	}
373 }
374 
375 static int tcp_handle_syn(struct tcp_syncookie *ctx)
376 {
377 	s64 csum;
378 
379 	if (tcp_validate_header(ctx))
380 		goto err;
381 
382 	tcp_parse_options(ctx);
383 
384 	if (tcp_validate_sysctl(ctx))
385 		goto err;
386 
387 	tcp_prepare_cookie(ctx);
388 	tcp_write_options(ctx);
389 
390 	swap(ctx->tcp->source, ctx->tcp->dest);
391 	ctx->tcp->check = 0;
392 	ctx->tcp->ack_seq = bpf_htonl(bpf_ntohl(ctx->tcp->seq) + 1);
393 	ctx->tcp->seq = bpf_htonl(ctx->cookie);
394 	ctx->tcp->doff = ((long)ctx->ptr32 - (long)ctx->tcp) >> 2;
395 	ctx->tcp->ack = 1;
396 	if (!ctx->attrs.tstamp_ok || !ctx->tcp->ece || !ctx->tcp->cwr)
397 		ctx->tcp->ece = 0;
398 	ctx->tcp->cwr = 0;
399 
400 	csum = bpf_csum_diff(0, 0, (void *)ctx->tcp, ctx->tcp->doff * 4, 0);
401 	if (csum < 0)
402 		goto err;
403 
404 	if (ctx->ipv4) {
405 		swap(ctx->ipv4->saddr, ctx->ipv4->daddr);
406 		ctx->tcp->check = tcp_v4_csum(ctx, csum);
407 
408 		ctx->ipv4->check = 0;
409 		ctx->ipv4->tos = 0;
410 		ctx->ipv4->tot_len = bpf_htons((long)ctx->ptr32 - (long)ctx->ipv4);
411 		ctx->ipv4->id = 0;
412 		ctx->ipv4->ttl = 64;
413 
414 		csum = bpf_csum_diff(0, 0, (void *)ctx->ipv4, sizeof(*ctx->ipv4), 0);
415 		if (csum < 0)
416 			goto err;
417 
418 		ctx->ipv4->check = csum_fold(csum);
419 	} else if (ctx->ipv6) {
420 		swap(ctx->ipv6->saddr, ctx->ipv6->daddr);
421 		ctx->tcp->check = tcp_v6_csum(ctx, csum);
422 
423 		*(__be32 *)ctx->ipv6 = bpf_htonl(0x60000000);
424 		ctx->ipv6->payload_len = bpf_htons((long)ctx->ptr32 - (long)ctx->tcp);
425 		ctx->ipv6->hop_limit = 64;
426 	}
427 
428 	swap_array(ctx->eth->h_source, ctx->eth->h_dest);
429 
430 	if (bpf_skb_change_tail(ctx->skb, (long)ctx->ptr32 - (long)ctx->eth, 0))
431 		goto err;
432 
433 	return bpf_redirect(ctx->skb->ifindex, 0);
434 err:
435 	return TC_ACT_SHOT;
436 }
437 
438 static int tcp_validate_cookie(struct tcp_syncookie *ctx)
439 {
440 	u32 cookie = bpf_ntohl(ctx->tcp->ack_seq) - 1;
441 	u32 seq = bpf_ntohl(ctx->tcp->seq) - 1;
442 	u64 first = 0, second;
443 	int mssind;
444 	u32 hash;
445 
446 	if (ctx->ipv4)
447 		first = (u64)ctx->ipv4->saddr << 32 | ctx->ipv4->daddr;
448 	else if (ctx->ipv6)
449 		first = (u64)ctx->ipv6->saddr.in6_u.u6_addr8[0] << 32 |
450 			ctx->ipv6->daddr.in6_u.u6_addr32[0];
451 
452 	second = (u64)seq << 32 | ctx->tcp->source << 16 | ctx->tcp->dest;
453 	hash = siphash_2u64(first, second, &test_key_siphash);
454 
455 	if (ctx->attrs.tstamp_ok)
456 		hash -= ctx->attrs.rcv_tsecr & COOKIE_MASK;
457 	else
458 		hash &= ~COOKIE_MASK;
459 
460 	hash -= cookie & ~COOKIE_MASK;
461 	if (hash)
462 		goto err;
463 
464 	mssind = (cookie & (3 << 6)) >> 6;
465 	if (ctx->ipv4) {
466 		if (mssind > ARRAY_SIZE(msstab4))
467 			goto err;
468 
469 		ctx->attrs.mss = msstab4[mssind];
470 	} else {
471 		if (mssind > ARRAY_SIZE(msstab6))
472 			goto err;
473 
474 		ctx->attrs.mss = msstab6[mssind];
475 	}
476 
477 	ctx->attrs.snd_wscale = cookie & BPF_SYNCOOKIE_WSCALE_MASK;
478 	ctx->attrs.rcv_wscale = ctx->attrs.snd_wscale;
479 	ctx->attrs.wscale_ok = ctx->attrs.snd_wscale == BPF_SYNCOOKIE_WSCALE_MASK;
480 	ctx->attrs.sack_ok = cookie & BPF_SYNCOOKIE_SACK;
481 	ctx->attrs.ecn_ok = cookie & BPF_SYNCOOKIE_ECN;
482 
483 	return 0;
484 err:
485 	return -1;
486 }
487 
488 static int tcp_handle_ack(struct tcp_syncookie *ctx)
489 {
490 	struct bpf_sock_tuple tuple;
491 	struct bpf_sock *skc;
492 	int ret = TC_ACT_OK;
493 	struct sock *sk;
494 	u32 tuple_size;
495 
496 	if (ctx->ipv4) {
497 		tuple.ipv4.saddr = ctx->ipv4->saddr;
498 		tuple.ipv4.daddr = ctx->ipv4->daddr;
499 		tuple.ipv4.sport = ctx->tcp->source;
500 		tuple.ipv4.dport = ctx->tcp->dest;
501 		tuple_size = sizeof(tuple.ipv4);
502 	} else if (ctx->ipv6) {
503 		__builtin_memcpy(tuple.ipv6.saddr, &ctx->ipv6->saddr, sizeof(tuple.ipv6.saddr));
504 		__builtin_memcpy(tuple.ipv6.daddr, &ctx->ipv6->daddr, sizeof(tuple.ipv6.daddr));
505 		tuple.ipv6.sport = ctx->tcp->source;
506 		tuple.ipv6.dport = ctx->tcp->dest;
507 		tuple_size = sizeof(tuple.ipv6);
508 	} else {
509 		goto out;
510 	}
511 
512 	skc = bpf_skc_lookup_tcp(ctx->skb, &tuple, tuple_size, -1, 0);
513 	if (!skc)
514 		goto out;
515 
516 	if (skc->state != TCP_LISTEN)
517 		goto release;
518 
519 	sk = (struct sock *)bpf_skc_to_tcp_sock(skc);
520 	if (!sk)
521 		goto err;
522 
523 	if (tcp_validate_header(ctx))
524 		goto err;
525 
526 	tcp_parse_options(ctx);
527 
528 	if (tcp_validate_cookie(ctx))
529 		goto err;
530 
531 	ret = bpf_sk_assign_tcp_reqsk(ctx->skb, sk, &ctx->attrs, sizeof(ctx->attrs));
532 	if (ret < 0)
533 		goto err;
534 
535 release:
536 	bpf_sk_release(skc);
537 out:
538 	return ret;
539 
540 err:
541 	ret = TC_ACT_SHOT;
542 	goto release;
543 }
544 
545 SEC("tc")
546 int tcp_custom_syncookie(struct __sk_buff *skb)
547 {
548 	struct tcp_syncookie ctx = {
549 		.skb = skb,
550 	};
551 
552 	if (tcp_load_headers(&ctx))
553 		return TC_ACT_OK;
554 
555 	if (ctx.tcp->rst)
556 		return TC_ACT_OK;
557 
558 	if (ctx.tcp->syn) {
559 		if (ctx.tcp->ack)
560 			return TC_ACT_OK;
561 
562 		handled_syn = true;
563 
564 		return tcp_handle_syn(&ctx);
565 	}
566 
567 	handled_ack = true;
568 
569 	return tcp_handle_ack(&ctx);
570 }
571 
572 char _license[] SEC("license") = "GPL";
573