xref: /linux/tools/testing/selftests/bpf/progs/test_tcp_custom_syncookie.c (revision 76d9b92e68f2bb55890f935c5143f4fef97a935d)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright Amazon.com Inc. or its affiliates. */
3 
4 #include "vmlinux.h"
5 
6 #include <bpf/bpf_helpers.h>
7 #include <bpf/bpf_endian.h>
8 #include "bpf_tracing_net.h"
9 #include "bpf_kfuncs.h"
10 #include "test_siphash.h"
11 #include "test_tcp_custom_syncookie.h"
12 #include "bpf_misc.h"
13 
14 #define MAX_PACKET_OFF 0xffff
15 
16 /* Hash is calculated for each client and split into ISN and TS.
17  *
18  *       MSB                                   LSB
19  * ISN:  | 31 ... 8 | 7 6 |   5 |    4 | 3 2 1 0 |
20  *       |   Hash_1 | MSS | ECN | SACK |  WScale |
21  *
22  * TS:   | 31 ... 8 |          7 ... 0           |
23  *       |   Random |           Hash_2           |
24  */
25 #define COOKIE_BITS	8
26 #define COOKIE_MASK	(((__u32)1 << COOKIE_BITS) - 1)
27 
28 enum {
29 	/* 0xf is invalid thus means that SYN did not have WScale. */
30 	BPF_SYNCOOKIE_WSCALE_MASK	= (1 << 4) - 1,
31 	BPF_SYNCOOKIE_SACK		= (1 << 4),
32 	BPF_SYNCOOKIE_ECN		= (1 << 5),
33 };
34 
35 #define MSS_LOCAL_IPV4	65495
36 #define MSS_LOCAL_IPV6	65476
37 
38 const __u16 msstab4[] = {
39 	536,
40 	1300,
41 	1460,
42 	MSS_LOCAL_IPV4,
43 };
44 
45 const __u16 msstab6[] = {
46 	1280 - 60, /* IPV6_MIN_MTU - 60 */
47 	1480 - 60,
48 	9000 - 60,
49 	MSS_LOCAL_IPV6,
50 };
51 
52 static siphash_key_t test_key_siphash = {
53 	{ 0x0706050403020100ULL, 0x0f0e0d0c0b0a0908ULL }
54 };
55 
56 struct tcp_syncookie {
57 	struct __sk_buff *skb;
58 	void *data;
59 	void *data_end;
60 	struct ethhdr *eth;
61 	struct iphdr *ipv4;
62 	struct ipv6hdr *ipv6;
63 	struct tcphdr *tcp;
64 	__be32 *ptr32;
65 	struct bpf_tcp_req_attrs attrs;
66 	u32 off;
67 	u32 cookie;
68 	u64 first;
69 };
70 
71 bool handled_syn, handled_ack;
72 
73 static int tcp_load_headers(struct tcp_syncookie *ctx)
74 {
75 	ctx->data = (void *)(long)ctx->skb->data;
76 	ctx->data_end = (void *)(long)ctx->skb->data_end;
77 	ctx->eth = (struct ethhdr *)(long)ctx->skb->data;
78 
79 	if (ctx->eth + 1 > ctx->data_end)
80 		goto err;
81 
82 	switch (bpf_ntohs(ctx->eth->h_proto)) {
83 	case ETH_P_IP:
84 		ctx->ipv4 = (struct iphdr *)(ctx->eth + 1);
85 
86 		if (ctx->ipv4 + 1 > ctx->data_end)
87 			goto err;
88 
89 		if (ctx->ipv4->ihl != sizeof(*ctx->ipv4) / 4)
90 			goto err;
91 
92 		if (ctx->ipv4->version != 4)
93 			goto err;
94 
95 		if (ctx->ipv4->protocol != IPPROTO_TCP)
96 			goto err;
97 
98 		ctx->tcp = (struct tcphdr *)(ctx->ipv4 + 1);
99 		break;
100 	case ETH_P_IPV6:
101 		ctx->ipv6 = (struct ipv6hdr *)(ctx->eth + 1);
102 
103 		if (ctx->ipv6 + 1 > ctx->data_end)
104 			goto err;
105 
106 		if (ctx->ipv6->version != 6)
107 			goto err;
108 
109 		if (ctx->ipv6->nexthdr != NEXTHDR_TCP)
110 			goto err;
111 
112 		ctx->tcp = (struct tcphdr *)(ctx->ipv6 + 1);
113 		break;
114 	default:
115 		goto err;
116 	}
117 
118 	if (ctx->tcp + 1 > ctx->data_end)
119 		goto err;
120 
121 	return 0;
122 err:
123 	return -1;
124 }
125 
126 static int tcp_reload_headers(struct tcp_syncookie *ctx)
127 {
128 	/* Without volatile,
129 	 * R3 32-bit pointer arithmetic prohibited
130 	 */
131 	volatile u64 data_len = ctx->skb->data_end - ctx->skb->data;
132 
133 	if (ctx->tcp->doff < sizeof(*ctx->tcp) / 4)
134 		goto err;
135 
136 	/* Needed to calculate csum and parse TCP options. */
137 	if (bpf_skb_change_tail(ctx->skb, data_len + 60 - ctx->tcp->doff * 4, 0))
138 		goto err;
139 
140 	ctx->data = (void *)(long)ctx->skb->data;
141 	ctx->data_end = (void *)(long)ctx->skb->data_end;
142 	ctx->eth = (struct ethhdr *)(long)ctx->skb->data;
143 	if (ctx->ipv4) {
144 		ctx->ipv4 = (struct iphdr *)(ctx->eth + 1);
145 		ctx->ipv6 = NULL;
146 		ctx->tcp = (struct tcphdr *)(ctx->ipv4 + 1);
147 	} else {
148 		ctx->ipv4 = NULL;
149 		ctx->ipv6 = (struct ipv6hdr *)(ctx->eth + 1);
150 		ctx->tcp = (struct tcphdr *)(ctx->ipv6 + 1);
151 	}
152 
153 	if ((void *)ctx->tcp + 60 > ctx->data_end)
154 		goto err;
155 
156 	return 0;
157 err:
158 	return -1;
159 }
160 
161 static __sum16 tcp_v4_csum(struct tcp_syncookie *ctx, __wsum csum)
162 {
163 	return csum_tcpudp_magic(ctx->ipv4->saddr, ctx->ipv4->daddr,
164 				 ctx->tcp->doff * 4, IPPROTO_TCP, csum);
165 }
166 
167 static __sum16 tcp_v6_csum(struct tcp_syncookie *ctx, __wsum csum)
168 {
169 	return csum_ipv6_magic(&ctx->ipv6->saddr, &ctx->ipv6->daddr,
170 			       ctx->tcp->doff * 4, IPPROTO_TCP, csum);
171 }
172 
173 static int tcp_validate_header(struct tcp_syncookie *ctx)
174 {
175 	s64 csum;
176 
177 	if (tcp_reload_headers(ctx))
178 		goto err;
179 
180 	csum = bpf_csum_diff(0, 0, (void *)ctx->tcp, ctx->tcp->doff * 4, 0);
181 	if (csum < 0)
182 		goto err;
183 
184 	if (ctx->ipv4) {
185 		/* check tcp_v4_csum(csum) is 0 if not on lo. */
186 
187 		csum = bpf_csum_diff(0, 0, (void *)ctx->ipv4, ctx->ipv4->ihl * 4, 0);
188 		if (csum < 0)
189 			goto err;
190 
191 		if (csum_fold(csum) != 0)
192 			goto err;
193 	} else if (ctx->ipv6) {
194 		/* check tcp_v6_csum(csum) is 0 if not on lo. */
195 	}
196 
197 	return 0;
198 err:
199 	return -1;
200 }
201 
202 static __always_inline void *next(struct tcp_syncookie *ctx, __u32 sz)
203 {
204 	__u64 off = ctx->off;
205 	__u8 *data;
206 
207 	/* Verifier forbids access to packet when offset exceeds MAX_PACKET_OFF */
208 	if (off > MAX_PACKET_OFF - sz)
209 		return NULL;
210 
211 	data = ctx->data + off;
212 	barrier_var(data);
213 	if (data + sz >= ctx->data_end)
214 		return NULL;
215 
216 	ctx->off += sz;
217 	return data;
218 }
219 
220 static int tcp_parse_option(__u32 index, struct tcp_syncookie *ctx)
221 {
222 	__u8 *opcode, *opsize, *wscale;
223 	__u32 *tsval, *tsecr;
224 	__u16 *mss;
225 	__u32 off;
226 
227 	off = ctx->off;
228 	opcode = next(ctx, 1);
229 	if (!opcode)
230 		goto stop;
231 
232 	if (*opcode == TCPOPT_EOL)
233 		goto stop;
234 
235 	if (*opcode == TCPOPT_NOP)
236 		goto next;
237 
238 	opsize = next(ctx, 1);
239 	if (!opsize)
240 		goto stop;
241 
242 	if (*opsize < 2)
243 		goto stop;
244 
245 	switch (*opcode) {
246 	case TCPOPT_MSS:
247 		mss = next(ctx, 2);
248 		if (*opsize == TCPOLEN_MSS && ctx->tcp->syn && mss)
249 			ctx->attrs.mss = get_unaligned_be16(mss);
250 		break;
251 	case TCPOPT_WINDOW:
252 		wscale = next(ctx, 1);
253 		if (*opsize == TCPOLEN_WINDOW && ctx->tcp->syn && wscale) {
254 			ctx->attrs.wscale_ok = 1;
255 			ctx->attrs.snd_wscale = *wscale;
256 		}
257 		break;
258 	case TCPOPT_TIMESTAMP:
259 		tsval = next(ctx, 4);
260 		tsecr = next(ctx, 4);
261 		if (*opsize == TCPOLEN_TIMESTAMP && tsval && tsecr) {
262 			ctx->attrs.rcv_tsval = get_unaligned_be32(tsval);
263 			ctx->attrs.rcv_tsecr = get_unaligned_be32(tsecr);
264 
265 			if (ctx->tcp->syn && ctx->attrs.rcv_tsecr)
266 				ctx->attrs.tstamp_ok = 0;
267 			else
268 				ctx->attrs.tstamp_ok = 1;
269 		}
270 		break;
271 	case TCPOPT_SACK_PERM:
272 		if (*opsize == TCPOLEN_SACK_PERM && ctx->tcp->syn)
273 			ctx->attrs.sack_ok = 1;
274 		break;
275 	}
276 
277 	ctx->off = off + *opsize;
278 next:
279 	return 0;
280 stop:
281 	return 1;
282 }
283 
284 static void tcp_parse_options(struct tcp_syncookie *ctx)
285 {
286 	ctx->off = (__u8 *)(ctx->tcp + 1) - (__u8 *)ctx->data,
287 
288 	bpf_loop(40, tcp_parse_option, ctx, 0);
289 }
290 
291 static int tcp_validate_sysctl(struct tcp_syncookie *ctx)
292 {
293 	if ((ctx->ipv4 && ctx->attrs.mss != MSS_LOCAL_IPV4) ||
294 	    (ctx->ipv6 && ctx->attrs.mss != MSS_LOCAL_IPV6))
295 		goto err;
296 
297 	if (!ctx->attrs.wscale_ok || ctx->attrs.snd_wscale != 7)
298 		goto err;
299 
300 	if (!ctx->attrs.tstamp_ok)
301 		goto err;
302 
303 	if (!ctx->attrs.sack_ok)
304 		goto err;
305 
306 	if (!ctx->tcp->ece || !ctx->tcp->cwr)
307 		goto err;
308 
309 	return 0;
310 err:
311 	return -1;
312 }
313 
314 static void tcp_prepare_cookie(struct tcp_syncookie *ctx)
315 {
316 	u32 seq = bpf_ntohl(ctx->tcp->seq);
317 	u64 first = 0, second;
318 	int mssind = 0;
319 	u32 hash;
320 
321 	if (ctx->ipv4) {
322 		for (mssind = ARRAY_SIZE(msstab4) - 1; mssind; mssind--)
323 			if (ctx->attrs.mss >= msstab4[mssind])
324 				break;
325 
326 		ctx->attrs.mss = msstab4[mssind];
327 
328 		first = (u64)ctx->ipv4->saddr << 32 | ctx->ipv4->daddr;
329 	} else if (ctx->ipv6) {
330 		for (mssind = ARRAY_SIZE(msstab6) - 1; mssind; mssind--)
331 			if (ctx->attrs.mss >= msstab6[mssind])
332 				break;
333 
334 		ctx->attrs.mss = msstab6[mssind];
335 
336 		first = (u64)ctx->ipv6->saddr.in6_u.u6_addr8[0] << 32 |
337 			ctx->ipv6->daddr.in6_u.u6_addr32[0];
338 	}
339 
340 	second = (u64)seq << 32 | ctx->tcp->source << 16 | ctx->tcp->dest;
341 	hash = siphash_2u64(first, second, &test_key_siphash);
342 
343 	if (ctx->attrs.tstamp_ok) {
344 		ctx->attrs.rcv_tsecr = bpf_get_prandom_u32();
345 		ctx->attrs.rcv_tsecr &= ~COOKIE_MASK;
346 		ctx->attrs.rcv_tsecr |= hash & COOKIE_MASK;
347 	}
348 
349 	hash &= ~COOKIE_MASK;
350 	hash |= mssind << 6;
351 
352 	if (ctx->attrs.wscale_ok)
353 		hash |= ctx->attrs.snd_wscale & BPF_SYNCOOKIE_WSCALE_MASK;
354 
355 	if (ctx->attrs.sack_ok)
356 		hash |= BPF_SYNCOOKIE_SACK;
357 
358 	if (ctx->attrs.tstamp_ok && ctx->tcp->ece && ctx->tcp->cwr)
359 		hash |= BPF_SYNCOOKIE_ECN;
360 
361 	ctx->cookie = hash;
362 }
363 
364 static void tcp_write_options(struct tcp_syncookie *ctx)
365 {
366 	ctx->ptr32 = (__be32 *)(ctx->tcp + 1);
367 
368 	*ctx->ptr32++ = bpf_htonl(TCPOPT_MSS << 24 | TCPOLEN_MSS << 16 |
369 				  ctx->attrs.mss);
370 
371 	if (ctx->attrs.wscale_ok)
372 		*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
373 					  TCPOPT_WINDOW << 16 |
374 					  TCPOLEN_WINDOW << 8 |
375 					  ctx->attrs.snd_wscale);
376 
377 	if (ctx->attrs.tstamp_ok) {
378 		if (ctx->attrs.sack_ok)
379 			*ctx->ptr32++ = bpf_htonl(TCPOPT_SACK_PERM << 24 |
380 						  TCPOLEN_SACK_PERM << 16 |
381 						  TCPOPT_TIMESTAMP << 8 |
382 						  TCPOLEN_TIMESTAMP);
383 		else
384 			*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
385 						  TCPOPT_NOP << 16 |
386 						  TCPOPT_TIMESTAMP << 8 |
387 						  TCPOLEN_TIMESTAMP);
388 
389 		*ctx->ptr32++ = bpf_htonl(ctx->attrs.rcv_tsecr);
390 		*ctx->ptr32++ = bpf_htonl(ctx->attrs.rcv_tsval);
391 	} else if (ctx->attrs.sack_ok) {
392 		*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
393 					  TCPOPT_NOP << 16 |
394 					  TCPOPT_SACK_PERM << 8 |
395 					  TCPOLEN_SACK_PERM);
396 	}
397 }
398 
399 static int tcp_handle_syn(struct tcp_syncookie *ctx)
400 {
401 	s64 csum;
402 
403 	if (tcp_validate_header(ctx))
404 		goto err;
405 
406 	tcp_parse_options(ctx);
407 
408 	if (tcp_validate_sysctl(ctx))
409 		goto err;
410 
411 	tcp_prepare_cookie(ctx);
412 	tcp_write_options(ctx);
413 
414 	swap(ctx->tcp->source, ctx->tcp->dest);
415 	ctx->tcp->check = 0;
416 	ctx->tcp->ack_seq = bpf_htonl(bpf_ntohl(ctx->tcp->seq) + 1);
417 	ctx->tcp->seq = bpf_htonl(ctx->cookie);
418 	ctx->tcp->doff = ((long)ctx->ptr32 - (long)ctx->tcp) >> 2;
419 	ctx->tcp->ack = 1;
420 	if (!ctx->attrs.tstamp_ok || !ctx->tcp->ece || !ctx->tcp->cwr)
421 		ctx->tcp->ece = 0;
422 	ctx->tcp->cwr = 0;
423 
424 	csum = bpf_csum_diff(0, 0, (void *)ctx->tcp, ctx->tcp->doff * 4, 0);
425 	if (csum < 0)
426 		goto err;
427 
428 	if (ctx->ipv4) {
429 		swap(ctx->ipv4->saddr, ctx->ipv4->daddr);
430 		ctx->tcp->check = tcp_v4_csum(ctx, csum);
431 
432 		ctx->ipv4->check = 0;
433 		ctx->ipv4->tos = 0;
434 		ctx->ipv4->tot_len = bpf_htons((long)ctx->ptr32 - (long)ctx->ipv4);
435 		ctx->ipv4->id = 0;
436 		ctx->ipv4->ttl = 64;
437 
438 		csum = bpf_csum_diff(0, 0, (void *)ctx->ipv4, sizeof(*ctx->ipv4), 0);
439 		if (csum < 0)
440 			goto err;
441 
442 		ctx->ipv4->check = csum_fold(csum);
443 	} else if (ctx->ipv6) {
444 		swap(ctx->ipv6->saddr, ctx->ipv6->daddr);
445 		ctx->tcp->check = tcp_v6_csum(ctx, csum);
446 
447 		*(__be32 *)ctx->ipv6 = bpf_htonl(0x60000000);
448 		ctx->ipv6->payload_len = bpf_htons((long)ctx->ptr32 - (long)ctx->tcp);
449 		ctx->ipv6->hop_limit = 64;
450 	}
451 
452 	swap_array(ctx->eth->h_source, ctx->eth->h_dest);
453 
454 	if (bpf_skb_change_tail(ctx->skb, (long)ctx->ptr32 - (long)ctx->eth, 0))
455 		goto err;
456 
457 	return bpf_redirect(ctx->skb->ifindex, 0);
458 err:
459 	return TC_ACT_SHOT;
460 }
461 
462 static int tcp_validate_cookie(struct tcp_syncookie *ctx)
463 {
464 	u32 cookie = bpf_ntohl(ctx->tcp->ack_seq) - 1;
465 	u32 seq = bpf_ntohl(ctx->tcp->seq) - 1;
466 	u64 first = 0, second;
467 	int mssind;
468 	u32 hash;
469 
470 	if (ctx->ipv4)
471 		first = (u64)ctx->ipv4->saddr << 32 | ctx->ipv4->daddr;
472 	else if (ctx->ipv6)
473 		first = (u64)ctx->ipv6->saddr.in6_u.u6_addr8[0] << 32 |
474 			ctx->ipv6->daddr.in6_u.u6_addr32[0];
475 
476 	second = (u64)seq << 32 | ctx->tcp->source << 16 | ctx->tcp->dest;
477 	hash = siphash_2u64(first, second, &test_key_siphash);
478 
479 	if (ctx->attrs.tstamp_ok)
480 		hash -= ctx->attrs.rcv_tsecr & COOKIE_MASK;
481 	else
482 		hash &= ~COOKIE_MASK;
483 
484 	hash -= cookie & ~COOKIE_MASK;
485 	if (hash)
486 		goto err;
487 
488 	mssind = (cookie & (3 << 6)) >> 6;
489 	if (ctx->ipv4) {
490 		if (mssind > ARRAY_SIZE(msstab4))
491 			goto err;
492 
493 		ctx->attrs.mss = msstab4[mssind];
494 	} else {
495 		if (mssind > ARRAY_SIZE(msstab6))
496 			goto err;
497 
498 		ctx->attrs.mss = msstab6[mssind];
499 	}
500 
501 	ctx->attrs.snd_wscale = cookie & BPF_SYNCOOKIE_WSCALE_MASK;
502 	ctx->attrs.rcv_wscale = ctx->attrs.snd_wscale;
503 	ctx->attrs.wscale_ok = ctx->attrs.snd_wscale == BPF_SYNCOOKIE_WSCALE_MASK;
504 	ctx->attrs.sack_ok = cookie & BPF_SYNCOOKIE_SACK;
505 	ctx->attrs.ecn_ok = cookie & BPF_SYNCOOKIE_ECN;
506 
507 	return 0;
508 err:
509 	return -1;
510 }
511 
512 static int tcp_handle_ack(struct tcp_syncookie *ctx)
513 {
514 	struct bpf_sock_tuple tuple;
515 	struct bpf_sock *skc;
516 	int ret = TC_ACT_OK;
517 	struct sock *sk;
518 	u32 tuple_size;
519 
520 	if (ctx->ipv4) {
521 		tuple.ipv4.saddr = ctx->ipv4->saddr;
522 		tuple.ipv4.daddr = ctx->ipv4->daddr;
523 		tuple.ipv4.sport = ctx->tcp->source;
524 		tuple.ipv4.dport = ctx->tcp->dest;
525 		tuple_size = sizeof(tuple.ipv4);
526 	} else if (ctx->ipv6) {
527 		__builtin_memcpy(tuple.ipv6.saddr, &ctx->ipv6->saddr, sizeof(tuple.ipv6.saddr));
528 		__builtin_memcpy(tuple.ipv6.daddr, &ctx->ipv6->daddr, sizeof(tuple.ipv6.daddr));
529 		tuple.ipv6.sport = ctx->tcp->source;
530 		tuple.ipv6.dport = ctx->tcp->dest;
531 		tuple_size = sizeof(tuple.ipv6);
532 	} else {
533 		goto out;
534 	}
535 
536 	skc = bpf_skc_lookup_tcp(ctx->skb, &tuple, tuple_size, -1, 0);
537 	if (!skc)
538 		goto out;
539 
540 	if (skc->state != TCP_LISTEN)
541 		goto release;
542 
543 	sk = (struct sock *)bpf_skc_to_tcp_sock(skc);
544 	if (!sk)
545 		goto err;
546 
547 	if (tcp_validate_header(ctx))
548 		goto err;
549 
550 	tcp_parse_options(ctx);
551 
552 	if (tcp_validate_cookie(ctx))
553 		goto err;
554 
555 	ret = bpf_sk_assign_tcp_reqsk(ctx->skb, sk, &ctx->attrs, sizeof(ctx->attrs));
556 	if (ret < 0)
557 		goto err;
558 
559 release:
560 	bpf_sk_release(skc);
561 out:
562 	return ret;
563 
564 err:
565 	ret = TC_ACT_SHOT;
566 	goto release;
567 }
568 
569 SEC("tc")
570 int tcp_custom_syncookie(struct __sk_buff *skb)
571 {
572 	struct tcp_syncookie ctx = {
573 		.skb = skb,
574 	};
575 
576 	if (tcp_load_headers(&ctx))
577 		return TC_ACT_OK;
578 
579 	if (ctx.tcp->rst)
580 		return TC_ACT_OK;
581 
582 	if (ctx.tcp->syn) {
583 		if (ctx.tcp->ack)
584 			return TC_ACT_OK;
585 
586 		handled_syn = true;
587 
588 		return tcp_handle_syn(&ctx);
589 	}
590 
591 	handled_ack = true;
592 
593 	return tcp_handle_ack(&ctx);
594 }
595 
596 char _license[] SEC("license") = "GPL";
597