1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright Amazon.com Inc. or its affiliates. */ 3 4 #include "vmlinux.h" 5 6 #include <bpf/bpf_helpers.h> 7 #include <bpf/bpf_endian.h> 8 #include "bpf_tracing_net.h" 9 #include "bpf_kfuncs.h" 10 #include "test_siphash.h" 11 #include "test_tcp_custom_syncookie.h" 12 #include "bpf_misc.h" 13 14 #define MAX_PACKET_OFF 0xffff 15 16 /* Hash is calculated for each client and split into ISN and TS. 17 * 18 * MSB LSB 19 * ISN: | 31 ... 8 | 7 6 | 5 | 4 | 3 2 1 0 | 20 * | Hash_1 | MSS | ECN | SACK | WScale | 21 * 22 * TS: | 31 ... 8 | 7 ... 0 | 23 * | Random | Hash_2 | 24 */ 25 #define COOKIE_BITS 8 26 #define COOKIE_MASK (((__u32)1 << COOKIE_BITS) - 1) 27 28 enum { 29 /* 0xf is invalid thus means that SYN did not have WScale. */ 30 BPF_SYNCOOKIE_WSCALE_MASK = (1 << 4) - 1, 31 BPF_SYNCOOKIE_SACK = (1 << 4), 32 BPF_SYNCOOKIE_ECN = (1 << 5), 33 }; 34 35 #define MSS_LOCAL_IPV4 65495 36 #define MSS_LOCAL_IPV6 65476 37 38 const __u16 msstab4[] = { 39 536, 40 1300, 41 1460, 42 MSS_LOCAL_IPV4, 43 }; 44 45 const __u16 msstab6[] = { 46 1280 - 60, /* IPV6_MIN_MTU - 60 */ 47 1480 - 60, 48 9000 - 60, 49 MSS_LOCAL_IPV6, 50 }; 51 52 static siphash_key_t test_key_siphash = { 53 { 0x0706050403020100ULL, 0x0f0e0d0c0b0a0908ULL } 54 }; 55 56 struct tcp_syncookie { 57 struct __sk_buff *skb; 58 void *data; 59 void *data_end; 60 struct ethhdr *eth; 61 struct iphdr *ipv4; 62 struct ipv6hdr *ipv6; 63 struct tcphdr *tcp; 64 __be32 *ptr32; 65 struct bpf_tcp_req_attrs attrs; 66 u32 off; 67 u32 cookie; 68 u64 first; 69 }; 70 71 bool handled_syn, handled_ack; 72 73 static int tcp_load_headers(struct tcp_syncookie *ctx) 74 { 75 ctx->data = (void *)(long)ctx->skb->data; 76 ctx->data_end = (void *)(long)ctx->skb->data_end; 77 ctx->eth = (struct ethhdr *)(long)ctx->skb->data; 78 79 if (ctx->eth + 1 > ctx->data_end) 80 goto err; 81 82 switch (bpf_ntohs(ctx->eth->h_proto)) { 83 case ETH_P_IP: 84 ctx->ipv4 = (struct iphdr *)(ctx->eth + 1); 85 86 if (ctx->ipv4 + 1 > ctx->data_end) 87 goto err; 88 89 if (ctx->ipv4->ihl != sizeof(*ctx->ipv4) / 4) 90 goto err; 91 92 if (ctx->ipv4->version != 4) 93 goto err; 94 95 if (ctx->ipv4->protocol != IPPROTO_TCP) 96 goto err; 97 98 ctx->tcp = (struct tcphdr *)(ctx->ipv4 + 1); 99 break; 100 case ETH_P_IPV6: 101 ctx->ipv6 = (struct ipv6hdr *)(ctx->eth + 1); 102 103 if (ctx->ipv6 + 1 > ctx->data_end) 104 goto err; 105 106 if (ctx->ipv6->version != 6) 107 goto err; 108 109 if (ctx->ipv6->nexthdr != NEXTHDR_TCP) 110 goto err; 111 112 ctx->tcp = (struct tcphdr *)(ctx->ipv6 + 1); 113 break; 114 default: 115 goto err; 116 } 117 118 if (ctx->tcp + 1 > ctx->data_end) 119 goto err; 120 121 return 0; 122 err: 123 return -1; 124 } 125 126 static int tcp_reload_headers(struct tcp_syncookie *ctx) 127 { 128 /* Without volatile, 129 * R3 32-bit pointer arithmetic prohibited 130 */ 131 volatile u64 data_len = ctx->skb->data_end - ctx->skb->data; 132 133 if (ctx->tcp->doff < sizeof(*ctx->tcp) / 4) 134 goto err; 135 136 /* Needed to calculate csum and parse TCP options. */ 137 if (bpf_skb_change_tail(ctx->skb, data_len + 60 - ctx->tcp->doff * 4, 0)) 138 goto err; 139 140 ctx->data = (void *)(long)ctx->skb->data; 141 ctx->data_end = (void *)(long)ctx->skb->data_end; 142 ctx->eth = (struct ethhdr *)(long)ctx->skb->data; 143 if (ctx->ipv4) { 144 ctx->ipv4 = (struct iphdr *)(ctx->eth + 1); 145 ctx->ipv6 = NULL; 146 ctx->tcp = (struct tcphdr *)(ctx->ipv4 + 1); 147 } else { 148 ctx->ipv4 = NULL; 149 ctx->ipv6 = (struct ipv6hdr *)(ctx->eth + 1); 150 ctx->tcp = (struct tcphdr *)(ctx->ipv6 + 1); 151 } 152 153 if ((void *)ctx->tcp + 60 > ctx->data_end) 154 goto err; 155 156 return 0; 157 err: 158 return -1; 159 } 160 161 static __sum16 tcp_v4_csum(struct tcp_syncookie *ctx, __wsum csum) 162 { 163 return csum_tcpudp_magic(ctx->ipv4->saddr, ctx->ipv4->daddr, 164 ctx->tcp->doff * 4, IPPROTO_TCP, csum); 165 } 166 167 static __sum16 tcp_v6_csum(struct tcp_syncookie *ctx, __wsum csum) 168 { 169 return csum_ipv6_magic(&ctx->ipv6->saddr, &ctx->ipv6->daddr, 170 ctx->tcp->doff * 4, IPPROTO_TCP, csum); 171 } 172 173 static int tcp_validate_header(struct tcp_syncookie *ctx) 174 { 175 s64 csum; 176 177 if (tcp_reload_headers(ctx)) 178 goto err; 179 180 csum = bpf_csum_diff(0, 0, (void *)ctx->tcp, ctx->tcp->doff * 4, 0); 181 if (csum < 0) 182 goto err; 183 184 if (ctx->ipv4) { 185 /* check tcp_v4_csum(csum) is 0 if not on lo. */ 186 187 csum = bpf_csum_diff(0, 0, (void *)ctx->ipv4, ctx->ipv4->ihl * 4, 0); 188 if (csum < 0) 189 goto err; 190 191 if (csum_fold(csum) != 0) 192 goto err; 193 } else if (ctx->ipv6) { 194 /* check tcp_v6_csum(csum) is 0 if not on lo. */ 195 } 196 197 return 0; 198 err: 199 return -1; 200 } 201 202 static __always_inline void *next(struct tcp_syncookie *ctx, __u32 sz) 203 { 204 __u64 off = ctx->off; 205 __u8 *data; 206 207 /* Verifier forbids access to packet when offset exceeds MAX_PACKET_OFF */ 208 if (off > MAX_PACKET_OFF - sz) 209 return NULL; 210 211 data = ctx->data + off; 212 barrier_var(data); 213 if (data + sz >= ctx->data_end) 214 return NULL; 215 216 ctx->off += sz; 217 return data; 218 } 219 220 static int tcp_parse_option(__u32 index, struct tcp_syncookie *ctx) 221 { 222 __u8 *opcode, *opsize, *wscale; 223 __u32 *tsval, *tsecr; 224 __u16 *mss; 225 __u32 off; 226 227 off = ctx->off; 228 opcode = next(ctx, 1); 229 if (!opcode) 230 goto stop; 231 232 if (*opcode == TCPOPT_EOL) 233 goto stop; 234 235 if (*opcode == TCPOPT_NOP) 236 goto next; 237 238 opsize = next(ctx, 1); 239 if (!opsize) 240 goto stop; 241 242 if (*opsize < 2) 243 goto stop; 244 245 switch (*opcode) { 246 case TCPOPT_MSS: 247 mss = next(ctx, 2); 248 if (*opsize == TCPOLEN_MSS && ctx->tcp->syn && mss) 249 ctx->attrs.mss = get_unaligned_be16(mss); 250 break; 251 case TCPOPT_WINDOW: 252 wscale = next(ctx, 1); 253 if (*opsize == TCPOLEN_WINDOW && ctx->tcp->syn && wscale) { 254 ctx->attrs.wscale_ok = 1; 255 ctx->attrs.snd_wscale = *wscale; 256 } 257 break; 258 case TCPOPT_TIMESTAMP: 259 tsval = next(ctx, 4); 260 tsecr = next(ctx, 4); 261 if (*opsize == TCPOLEN_TIMESTAMP && tsval && tsecr) { 262 ctx->attrs.rcv_tsval = get_unaligned_be32(tsval); 263 ctx->attrs.rcv_tsecr = get_unaligned_be32(tsecr); 264 265 if (ctx->tcp->syn && ctx->attrs.rcv_tsecr) 266 ctx->attrs.tstamp_ok = 0; 267 else 268 ctx->attrs.tstamp_ok = 1; 269 } 270 break; 271 case TCPOPT_SACK_PERM: 272 if (*opsize == TCPOLEN_SACK_PERM && ctx->tcp->syn) 273 ctx->attrs.sack_ok = 1; 274 break; 275 } 276 277 ctx->off = off + *opsize; 278 next: 279 return 0; 280 stop: 281 return 1; 282 } 283 284 static void tcp_parse_options(struct tcp_syncookie *ctx) 285 { 286 ctx->off = (__u8 *)(ctx->tcp + 1) - (__u8 *)ctx->data, 287 288 bpf_loop(40, tcp_parse_option, ctx, 0); 289 } 290 291 static int tcp_validate_sysctl(struct tcp_syncookie *ctx) 292 { 293 if ((ctx->ipv4 && ctx->attrs.mss != MSS_LOCAL_IPV4) || 294 (ctx->ipv6 && ctx->attrs.mss != MSS_LOCAL_IPV6)) 295 goto err; 296 297 if (!ctx->attrs.wscale_ok || ctx->attrs.snd_wscale != 7) 298 goto err; 299 300 if (!ctx->attrs.tstamp_ok) 301 goto err; 302 303 if (!ctx->attrs.sack_ok) 304 goto err; 305 306 if (!ctx->tcp->ece || !ctx->tcp->cwr) 307 goto err; 308 309 return 0; 310 err: 311 return -1; 312 } 313 314 static void tcp_prepare_cookie(struct tcp_syncookie *ctx) 315 { 316 u32 seq = bpf_ntohl(ctx->tcp->seq); 317 u64 first = 0, second; 318 int mssind = 0; 319 u32 hash; 320 321 if (ctx->ipv4) { 322 for (mssind = ARRAY_SIZE(msstab4) - 1; mssind; mssind--) 323 if (ctx->attrs.mss >= msstab4[mssind]) 324 break; 325 326 ctx->attrs.mss = msstab4[mssind]; 327 328 first = (u64)ctx->ipv4->saddr << 32 | ctx->ipv4->daddr; 329 } else if (ctx->ipv6) { 330 for (mssind = ARRAY_SIZE(msstab6) - 1; mssind; mssind--) 331 if (ctx->attrs.mss >= msstab6[mssind]) 332 break; 333 334 ctx->attrs.mss = msstab6[mssind]; 335 336 first = (u64)ctx->ipv6->saddr.in6_u.u6_addr8[0] << 32 | 337 ctx->ipv6->daddr.in6_u.u6_addr32[0]; 338 } 339 340 second = (u64)seq << 32 | ctx->tcp->source << 16 | ctx->tcp->dest; 341 hash = siphash_2u64(first, second, &test_key_siphash); 342 343 if (ctx->attrs.tstamp_ok) { 344 ctx->attrs.rcv_tsecr = bpf_get_prandom_u32(); 345 ctx->attrs.rcv_tsecr &= ~COOKIE_MASK; 346 ctx->attrs.rcv_tsecr |= hash & COOKIE_MASK; 347 } 348 349 hash &= ~COOKIE_MASK; 350 hash |= mssind << 6; 351 352 if (ctx->attrs.wscale_ok) 353 hash |= ctx->attrs.snd_wscale & BPF_SYNCOOKIE_WSCALE_MASK; 354 355 if (ctx->attrs.sack_ok) 356 hash |= BPF_SYNCOOKIE_SACK; 357 358 if (ctx->attrs.tstamp_ok && ctx->tcp->ece && ctx->tcp->cwr) 359 hash |= BPF_SYNCOOKIE_ECN; 360 361 ctx->cookie = hash; 362 } 363 364 static void tcp_write_options(struct tcp_syncookie *ctx) 365 { 366 ctx->ptr32 = (__be32 *)(ctx->tcp + 1); 367 368 *ctx->ptr32++ = bpf_htonl(TCPOPT_MSS << 24 | TCPOLEN_MSS << 16 | 369 ctx->attrs.mss); 370 371 if (ctx->attrs.wscale_ok) 372 *ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 | 373 TCPOPT_WINDOW << 16 | 374 TCPOLEN_WINDOW << 8 | 375 ctx->attrs.snd_wscale); 376 377 if (ctx->attrs.tstamp_ok) { 378 if (ctx->attrs.sack_ok) 379 *ctx->ptr32++ = bpf_htonl(TCPOPT_SACK_PERM << 24 | 380 TCPOLEN_SACK_PERM << 16 | 381 TCPOPT_TIMESTAMP << 8 | 382 TCPOLEN_TIMESTAMP); 383 else 384 *ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 | 385 TCPOPT_NOP << 16 | 386 TCPOPT_TIMESTAMP << 8 | 387 TCPOLEN_TIMESTAMP); 388 389 *ctx->ptr32++ = bpf_htonl(ctx->attrs.rcv_tsecr); 390 *ctx->ptr32++ = bpf_htonl(ctx->attrs.rcv_tsval); 391 } else if (ctx->attrs.sack_ok) { 392 *ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 | 393 TCPOPT_NOP << 16 | 394 TCPOPT_SACK_PERM << 8 | 395 TCPOLEN_SACK_PERM); 396 } 397 } 398 399 static int tcp_handle_syn(struct tcp_syncookie *ctx) 400 { 401 s64 csum; 402 403 if (tcp_validate_header(ctx)) 404 goto err; 405 406 tcp_parse_options(ctx); 407 408 if (tcp_validate_sysctl(ctx)) 409 goto err; 410 411 tcp_prepare_cookie(ctx); 412 tcp_write_options(ctx); 413 414 swap(ctx->tcp->source, ctx->tcp->dest); 415 ctx->tcp->check = 0; 416 ctx->tcp->ack_seq = bpf_htonl(bpf_ntohl(ctx->tcp->seq) + 1); 417 ctx->tcp->seq = bpf_htonl(ctx->cookie); 418 ctx->tcp->doff = ((long)ctx->ptr32 - (long)ctx->tcp) >> 2; 419 ctx->tcp->ack = 1; 420 if (!ctx->attrs.tstamp_ok || !ctx->tcp->ece || !ctx->tcp->cwr) 421 ctx->tcp->ece = 0; 422 ctx->tcp->cwr = 0; 423 424 csum = bpf_csum_diff(0, 0, (void *)ctx->tcp, ctx->tcp->doff * 4, 0); 425 if (csum < 0) 426 goto err; 427 428 if (ctx->ipv4) { 429 swap(ctx->ipv4->saddr, ctx->ipv4->daddr); 430 ctx->tcp->check = tcp_v4_csum(ctx, csum); 431 432 ctx->ipv4->check = 0; 433 ctx->ipv4->tos = 0; 434 ctx->ipv4->tot_len = bpf_htons((long)ctx->ptr32 - (long)ctx->ipv4); 435 ctx->ipv4->id = 0; 436 ctx->ipv4->ttl = 64; 437 438 csum = bpf_csum_diff(0, 0, (void *)ctx->ipv4, sizeof(*ctx->ipv4), 0); 439 if (csum < 0) 440 goto err; 441 442 ctx->ipv4->check = csum_fold(csum); 443 } else if (ctx->ipv6) { 444 swap(ctx->ipv6->saddr, ctx->ipv6->daddr); 445 ctx->tcp->check = tcp_v6_csum(ctx, csum); 446 447 *(__be32 *)ctx->ipv6 = bpf_htonl(0x60000000); 448 ctx->ipv6->payload_len = bpf_htons((long)ctx->ptr32 - (long)ctx->tcp); 449 ctx->ipv6->hop_limit = 64; 450 } 451 452 swap_array(ctx->eth->h_source, ctx->eth->h_dest); 453 454 if (bpf_skb_change_tail(ctx->skb, (long)ctx->ptr32 - (long)ctx->eth, 0)) 455 goto err; 456 457 return bpf_redirect(ctx->skb->ifindex, 0); 458 err: 459 return TC_ACT_SHOT; 460 } 461 462 static int tcp_validate_cookie(struct tcp_syncookie *ctx) 463 { 464 u32 cookie = bpf_ntohl(ctx->tcp->ack_seq) - 1; 465 u32 seq = bpf_ntohl(ctx->tcp->seq) - 1; 466 u64 first = 0, second; 467 int mssind; 468 u32 hash; 469 470 if (ctx->ipv4) 471 first = (u64)ctx->ipv4->saddr << 32 | ctx->ipv4->daddr; 472 else if (ctx->ipv6) 473 first = (u64)ctx->ipv6->saddr.in6_u.u6_addr8[0] << 32 | 474 ctx->ipv6->daddr.in6_u.u6_addr32[0]; 475 476 second = (u64)seq << 32 | ctx->tcp->source << 16 | ctx->tcp->dest; 477 hash = siphash_2u64(first, second, &test_key_siphash); 478 479 if (ctx->attrs.tstamp_ok) 480 hash -= ctx->attrs.rcv_tsecr & COOKIE_MASK; 481 else 482 hash &= ~COOKIE_MASK; 483 484 hash -= cookie & ~COOKIE_MASK; 485 if (hash) 486 goto err; 487 488 mssind = (cookie & (3 << 6)) >> 6; 489 if (ctx->ipv4) { 490 if (mssind > ARRAY_SIZE(msstab4)) 491 goto err; 492 493 ctx->attrs.mss = msstab4[mssind]; 494 } else { 495 if (mssind > ARRAY_SIZE(msstab6)) 496 goto err; 497 498 ctx->attrs.mss = msstab6[mssind]; 499 } 500 501 ctx->attrs.snd_wscale = cookie & BPF_SYNCOOKIE_WSCALE_MASK; 502 ctx->attrs.rcv_wscale = ctx->attrs.snd_wscale; 503 ctx->attrs.wscale_ok = ctx->attrs.snd_wscale == BPF_SYNCOOKIE_WSCALE_MASK; 504 ctx->attrs.sack_ok = cookie & BPF_SYNCOOKIE_SACK; 505 ctx->attrs.ecn_ok = cookie & BPF_SYNCOOKIE_ECN; 506 507 return 0; 508 err: 509 return -1; 510 } 511 512 static int tcp_handle_ack(struct tcp_syncookie *ctx) 513 { 514 struct bpf_sock_tuple tuple; 515 struct bpf_sock *skc; 516 int ret = TC_ACT_OK; 517 struct sock *sk; 518 u32 tuple_size; 519 520 if (ctx->ipv4) { 521 tuple.ipv4.saddr = ctx->ipv4->saddr; 522 tuple.ipv4.daddr = ctx->ipv4->daddr; 523 tuple.ipv4.sport = ctx->tcp->source; 524 tuple.ipv4.dport = ctx->tcp->dest; 525 tuple_size = sizeof(tuple.ipv4); 526 } else if (ctx->ipv6) { 527 __builtin_memcpy(tuple.ipv6.saddr, &ctx->ipv6->saddr, sizeof(tuple.ipv6.saddr)); 528 __builtin_memcpy(tuple.ipv6.daddr, &ctx->ipv6->daddr, sizeof(tuple.ipv6.daddr)); 529 tuple.ipv6.sport = ctx->tcp->source; 530 tuple.ipv6.dport = ctx->tcp->dest; 531 tuple_size = sizeof(tuple.ipv6); 532 } else { 533 goto out; 534 } 535 536 skc = bpf_skc_lookup_tcp(ctx->skb, &tuple, tuple_size, -1, 0); 537 if (!skc) 538 goto out; 539 540 if (skc->state != TCP_LISTEN) 541 goto release; 542 543 sk = (struct sock *)bpf_skc_to_tcp_sock(skc); 544 if (!sk) 545 goto err; 546 547 if (tcp_validate_header(ctx)) 548 goto err; 549 550 tcp_parse_options(ctx); 551 552 if (tcp_validate_cookie(ctx)) 553 goto err; 554 555 ret = bpf_sk_assign_tcp_reqsk(ctx->skb, sk, &ctx->attrs, sizeof(ctx->attrs)); 556 if (ret < 0) 557 goto err; 558 559 release: 560 bpf_sk_release(skc); 561 out: 562 return ret; 563 564 err: 565 ret = TC_ACT_SHOT; 566 goto release; 567 } 568 569 SEC("tc") 570 int tcp_custom_syncookie(struct __sk_buff *skb) 571 { 572 struct tcp_syncookie ctx = { 573 .skb = skb, 574 }; 575 576 if (tcp_load_headers(&ctx)) 577 return TC_ACT_OK; 578 579 if (ctx.tcp->rst) 580 return TC_ACT_OK; 581 582 if (ctx.tcp->syn) { 583 if (ctx.tcp->ack) 584 return TC_ACT_OK; 585 586 handled_syn = true; 587 588 return tcp_handle_syn(&ctx); 589 } 590 591 handled_ack = true; 592 593 return tcp_handle_ack(&ctx); 594 } 595 596 char _license[] SEC("license") = "GPL"; 597