1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright Amazon.com Inc. or its affiliates. */ 3 4 #include "vmlinux.h" 5 6 #include <bpf/bpf_helpers.h> 7 #include <bpf/bpf_endian.h> 8 #include "bpf_tracing_net.h" 9 #include "bpf_kfuncs.h" 10 #include "test_siphash.h" 11 #include "test_tcp_custom_syncookie.h" 12 13 #define MAX_PACKET_OFF 0xffff 14 15 /* Hash is calculated for each client and split into ISN and TS. 16 * 17 * MSB LSB 18 * ISN: | 31 ... 8 | 7 6 | 5 | 4 | 3 2 1 0 | 19 * | Hash_1 | MSS | ECN | SACK | WScale | 20 * 21 * TS: | 31 ... 8 | 7 ... 0 | 22 * | Random | Hash_2 | 23 */ 24 #define COOKIE_BITS 8 25 #define COOKIE_MASK (((__u32)1 << COOKIE_BITS) - 1) 26 27 enum { 28 /* 0xf is invalid thus means that SYN did not have WScale. */ 29 BPF_SYNCOOKIE_WSCALE_MASK = (1 << 4) - 1, 30 BPF_SYNCOOKIE_SACK = (1 << 4), 31 BPF_SYNCOOKIE_ECN = (1 << 5), 32 }; 33 34 #define MSS_LOCAL_IPV4 65495 35 #define MSS_LOCAL_IPV6 65476 36 37 const __u16 msstab4[] = { 38 536, 39 1300, 40 1460, 41 MSS_LOCAL_IPV4, 42 }; 43 44 const __u16 msstab6[] = { 45 1280 - 60, /* IPV6_MIN_MTU - 60 */ 46 1480 - 60, 47 9000 - 60, 48 MSS_LOCAL_IPV6, 49 }; 50 51 static siphash_key_t test_key_siphash = { 52 { 0x0706050403020100ULL, 0x0f0e0d0c0b0a0908ULL } 53 }; 54 55 struct tcp_syncookie { 56 struct __sk_buff *skb; 57 void *data; 58 void *data_end; 59 struct ethhdr *eth; 60 struct iphdr *ipv4; 61 struct ipv6hdr *ipv6; 62 struct tcphdr *tcp; 63 __be32 *ptr32; 64 struct bpf_tcp_req_attrs attrs; 65 u32 off; 66 u32 cookie; 67 u64 first; 68 }; 69 70 bool handled_syn, handled_ack; 71 72 static int tcp_load_headers(struct tcp_syncookie *ctx) 73 { 74 ctx->data = (void *)(long)ctx->skb->data; 75 ctx->data_end = (void *)(long)ctx->skb->data_end; 76 ctx->eth = (struct ethhdr *)(long)ctx->skb->data; 77 78 if (ctx->eth + 1 > ctx->data_end) 79 goto err; 80 81 switch (bpf_ntohs(ctx->eth->h_proto)) { 82 case ETH_P_IP: 83 ctx->ipv4 = (struct iphdr *)(ctx->eth + 1); 84 85 if (ctx->ipv4 + 1 > ctx->data_end) 86 goto err; 87 88 if (ctx->ipv4->ihl != sizeof(*ctx->ipv4) / 4) 89 goto err; 90 91 if (ctx->ipv4->version != 4) 92 goto err; 93 94 if (ctx->ipv4->protocol != IPPROTO_TCP) 95 goto err; 96 97 ctx->tcp = (struct tcphdr *)(ctx->ipv4 + 1); 98 break; 99 case ETH_P_IPV6: 100 ctx->ipv6 = (struct ipv6hdr *)(ctx->eth + 1); 101 102 if (ctx->ipv6 + 1 > ctx->data_end) 103 goto err; 104 105 if (ctx->ipv6->version != 6) 106 goto err; 107 108 if (ctx->ipv6->nexthdr != NEXTHDR_TCP) 109 goto err; 110 111 ctx->tcp = (struct tcphdr *)(ctx->ipv6 + 1); 112 break; 113 default: 114 goto err; 115 } 116 117 if (ctx->tcp + 1 > ctx->data_end) 118 goto err; 119 120 return 0; 121 err: 122 return -1; 123 } 124 125 static int tcp_reload_headers(struct tcp_syncookie *ctx) 126 { 127 /* Without volatile, 128 * R3 32-bit pointer arithmetic prohibited 129 */ 130 volatile u64 data_len = ctx->skb->data_end - ctx->skb->data; 131 132 if (ctx->tcp->doff < sizeof(*ctx->tcp) / 4) 133 goto err; 134 135 /* Needed to calculate csum and parse TCP options. */ 136 if (bpf_skb_change_tail(ctx->skb, data_len + 60 - ctx->tcp->doff * 4, 0)) 137 goto err; 138 139 ctx->data = (void *)(long)ctx->skb->data; 140 ctx->data_end = (void *)(long)ctx->skb->data_end; 141 ctx->eth = (struct ethhdr *)(long)ctx->skb->data; 142 if (ctx->ipv4) { 143 ctx->ipv4 = (struct iphdr *)(ctx->eth + 1); 144 ctx->ipv6 = NULL; 145 ctx->tcp = (struct tcphdr *)(ctx->ipv4 + 1); 146 } else { 147 ctx->ipv4 = NULL; 148 ctx->ipv6 = (struct ipv6hdr *)(ctx->eth + 1); 149 ctx->tcp = (struct tcphdr *)(ctx->ipv6 + 1); 150 } 151 152 if ((void *)ctx->tcp + 60 > ctx->data_end) 153 goto err; 154 155 return 0; 156 err: 157 return -1; 158 } 159 160 static __sum16 tcp_v4_csum(struct tcp_syncookie *ctx, __wsum csum) 161 { 162 return csum_tcpudp_magic(ctx->ipv4->saddr, ctx->ipv4->daddr, 163 ctx->tcp->doff * 4, IPPROTO_TCP, csum); 164 } 165 166 static __sum16 tcp_v6_csum(struct tcp_syncookie *ctx, __wsum csum) 167 { 168 return csum_ipv6_magic(&ctx->ipv6->saddr, &ctx->ipv6->daddr, 169 ctx->tcp->doff * 4, IPPROTO_TCP, csum); 170 } 171 172 static int tcp_validate_header(struct tcp_syncookie *ctx) 173 { 174 s64 csum; 175 176 if (tcp_reload_headers(ctx)) 177 goto err; 178 179 csum = bpf_csum_diff(0, 0, (void *)ctx->tcp, ctx->tcp->doff * 4, 0); 180 if (csum < 0) 181 goto err; 182 183 if (ctx->ipv4) { 184 /* check tcp_v4_csum(csum) is 0 if not on lo. */ 185 186 csum = bpf_csum_diff(0, 0, (void *)ctx->ipv4, ctx->ipv4->ihl * 4, 0); 187 if (csum < 0) 188 goto err; 189 190 if (csum_fold(csum) != 0) 191 goto err; 192 } else if (ctx->ipv6) { 193 /* check tcp_v6_csum(csum) is 0 if not on lo. */ 194 } 195 196 return 0; 197 err: 198 return -1; 199 } 200 201 static __always_inline void *next(struct tcp_syncookie *ctx, __u32 sz) 202 { 203 __u64 off = ctx->off; 204 __u8 *data; 205 206 /* Verifier forbids access to packet when offset exceeds MAX_PACKET_OFF */ 207 if (off > MAX_PACKET_OFF - sz) 208 return NULL; 209 210 data = ctx->data + off; 211 barrier_var(data); 212 if (data + sz >= ctx->data_end) 213 return NULL; 214 215 ctx->off += sz; 216 return data; 217 } 218 219 static int tcp_parse_option(__u32 index, struct tcp_syncookie *ctx) 220 { 221 __u8 *opcode, *opsize, *wscale; 222 __u32 *tsval, *tsecr; 223 __u16 *mss; 224 __u32 off; 225 226 off = ctx->off; 227 opcode = next(ctx, 1); 228 if (!opcode) 229 goto stop; 230 231 if (*opcode == TCPOPT_EOL) 232 goto stop; 233 234 if (*opcode == TCPOPT_NOP) 235 goto next; 236 237 opsize = next(ctx, 1); 238 if (!opsize) 239 goto stop; 240 241 if (*opsize < 2) 242 goto stop; 243 244 switch (*opcode) { 245 case TCPOPT_MSS: 246 mss = next(ctx, 2); 247 if (*opsize == TCPOLEN_MSS && ctx->tcp->syn && mss) 248 ctx->attrs.mss = get_unaligned_be16(mss); 249 break; 250 case TCPOPT_WINDOW: 251 wscale = next(ctx, 1); 252 if (*opsize == TCPOLEN_WINDOW && ctx->tcp->syn && wscale) { 253 ctx->attrs.wscale_ok = 1; 254 ctx->attrs.snd_wscale = *wscale; 255 } 256 break; 257 case TCPOPT_TIMESTAMP: 258 tsval = next(ctx, 4); 259 tsecr = next(ctx, 4); 260 if (*opsize == TCPOLEN_TIMESTAMP && tsval && tsecr) { 261 ctx->attrs.rcv_tsval = get_unaligned_be32(tsval); 262 ctx->attrs.rcv_tsecr = get_unaligned_be32(tsecr); 263 264 if (ctx->tcp->syn && ctx->attrs.rcv_tsecr) 265 ctx->attrs.tstamp_ok = 0; 266 else 267 ctx->attrs.tstamp_ok = 1; 268 } 269 break; 270 case TCPOPT_SACK_PERM: 271 if (*opsize == TCPOLEN_SACK_PERM && ctx->tcp->syn) 272 ctx->attrs.sack_ok = 1; 273 break; 274 } 275 276 ctx->off = off + *opsize; 277 next: 278 return 0; 279 stop: 280 return 1; 281 } 282 283 static void tcp_parse_options(struct tcp_syncookie *ctx) 284 { 285 ctx->off = (__u8 *)(ctx->tcp + 1) - (__u8 *)ctx->data, 286 287 bpf_loop(40, tcp_parse_option, ctx, 0); 288 } 289 290 static int tcp_validate_sysctl(struct tcp_syncookie *ctx) 291 { 292 if ((ctx->ipv4 && ctx->attrs.mss != MSS_LOCAL_IPV4) || 293 (ctx->ipv6 && ctx->attrs.mss != MSS_LOCAL_IPV6)) 294 goto err; 295 296 if (!ctx->attrs.wscale_ok || ctx->attrs.snd_wscale != 7) 297 goto err; 298 299 if (!ctx->attrs.tstamp_ok) 300 goto err; 301 302 if (!ctx->attrs.sack_ok) 303 goto err; 304 305 if (!ctx->tcp->ece || !ctx->tcp->cwr) 306 goto err; 307 308 return 0; 309 err: 310 return -1; 311 } 312 313 static void tcp_prepare_cookie(struct tcp_syncookie *ctx) 314 { 315 u32 seq = bpf_ntohl(ctx->tcp->seq); 316 u64 first = 0, second; 317 int mssind = 0; 318 u32 hash; 319 320 if (ctx->ipv4) { 321 for (mssind = ARRAY_SIZE(msstab4) - 1; mssind; mssind--) 322 if (ctx->attrs.mss >= msstab4[mssind]) 323 break; 324 325 ctx->attrs.mss = msstab4[mssind]; 326 327 first = (u64)ctx->ipv4->saddr << 32 | ctx->ipv4->daddr; 328 } else if (ctx->ipv6) { 329 for (mssind = ARRAY_SIZE(msstab6) - 1; mssind; mssind--) 330 if (ctx->attrs.mss >= msstab6[mssind]) 331 break; 332 333 ctx->attrs.mss = msstab6[mssind]; 334 335 first = (u64)ctx->ipv6->saddr.in6_u.u6_addr8[0] << 32 | 336 ctx->ipv6->daddr.in6_u.u6_addr32[0]; 337 } 338 339 second = (u64)seq << 32 | ctx->tcp->source << 16 | ctx->tcp->dest; 340 hash = siphash_2u64(first, second, &test_key_siphash); 341 342 if (ctx->attrs.tstamp_ok) { 343 ctx->attrs.rcv_tsecr = bpf_get_prandom_u32(); 344 ctx->attrs.rcv_tsecr &= ~COOKIE_MASK; 345 ctx->attrs.rcv_tsecr |= hash & COOKIE_MASK; 346 } 347 348 hash &= ~COOKIE_MASK; 349 hash |= mssind << 6; 350 351 if (ctx->attrs.wscale_ok) 352 hash |= ctx->attrs.snd_wscale & BPF_SYNCOOKIE_WSCALE_MASK; 353 354 if (ctx->attrs.sack_ok) 355 hash |= BPF_SYNCOOKIE_SACK; 356 357 if (ctx->attrs.tstamp_ok && ctx->tcp->ece && ctx->tcp->cwr) 358 hash |= BPF_SYNCOOKIE_ECN; 359 360 ctx->cookie = hash; 361 } 362 363 static void tcp_write_options(struct tcp_syncookie *ctx) 364 { 365 ctx->ptr32 = (__be32 *)(ctx->tcp + 1); 366 367 *ctx->ptr32++ = bpf_htonl(TCPOPT_MSS << 24 | TCPOLEN_MSS << 16 | 368 ctx->attrs.mss); 369 370 if (ctx->attrs.wscale_ok) 371 *ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 | 372 TCPOPT_WINDOW << 16 | 373 TCPOLEN_WINDOW << 8 | 374 ctx->attrs.snd_wscale); 375 376 if (ctx->attrs.tstamp_ok) { 377 if (ctx->attrs.sack_ok) 378 *ctx->ptr32++ = bpf_htonl(TCPOPT_SACK_PERM << 24 | 379 TCPOLEN_SACK_PERM << 16 | 380 TCPOPT_TIMESTAMP << 8 | 381 TCPOLEN_TIMESTAMP); 382 else 383 *ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 | 384 TCPOPT_NOP << 16 | 385 TCPOPT_TIMESTAMP << 8 | 386 TCPOLEN_TIMESTAMP); 387 388 *ctx->ptr32++ = bpf_htonl(ctx->attrs.rcv_tsecr); 389 *ctx->ptr32++ = bpf_htonl(ctx->attrs.rcv_tsval); 390 } else if (ctx->attrs.sack_ok) { 391 *ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 | 392 TCPOPT_NOP << 16 | 393 TCPOPT_SACK_PERM << 8 | 394 TCPOLEN_SACK_PERM); 395 } 396 } 397 398 static int tcp_handle_syn(struct tcp_syncookie *ctx) 399 { 400 s64 csum; 401 402 if (tcp_validate_header(ctx)) 403 goto err; 404 405 tcp_parse_options(ctx); 406 407 if (tcp_validate_sysctl(ctx)) 408 goto err; 409 410 tcp_prepare_cookie(ctx); 411 tcp_write_options(ctx); 412 413 swap(ctx->tcp->source, ctx->tcp->dest); 414 ctx->tcp->check = 0; 415 ctx->tcp->ack_seq = bpf_htonl(bpf_ntohl(ctx->tcp->seq) + 1); 416 ctx->tcp->seq = bpf_htonl(ctx->cookie); 417 ctx->tcp->doff = ((long)ctx->ptr32 - (long)ctx->tcp) >> 2; 418 ctx->tcp->ack = 1; 419 if (!ctx->attrs.tstamp_ok || !ctx->tcp->ece || !ctx->tcp->cwr) 420 ctx->tcp->ece = 0; 421 ctx->tcp->cwr = 0; 422 423 csum = bpf_csum_diff(0, 0, (void *)ctx->tcp, ctx->tcp->doff * 4, 0); 424 if (csum < 0) 425 goto err; 426 427 if (ctx->ipv4) { 428 swap(ctx->ipv4->saddr, ctx->ipv4->daddr); 429 ctx->tcp->check = tcp_v4_csum(ctx, csum); 430 431 ctx->ipv4->check = 0; 432 ctx->ipv4->tos = 0; 433 ctx->ipv4->tot_len = bpf_htons((long)ctx->ptr32 - (long)ctx->ipv4); 434 ctx->ipv4->id = 0; 435 ctx->ipv4->ttl = 64; 436 437 csum = bpf_csum_diff(0, 0, (void *)ctx->ipv4, sizeof(*ctx->ipv4), 0); 438 if (csum < 0) 439 goto err; 440 441 ctx->ipv4->check = csum_fold(csum); 442 } else if (ctx->ipv6) { 443 swap(ctx->ipv6->saddr, ctx->ipv6->daddr); 444 ctx->tcp->check = tcp_v6_csum(ctx, csum); 445 446 *(__be32 *)ctx->ipv6 = bpf_htonl(0x60000000); 447 ctx->ipv6->payload_len = bpf_htons((long)ctx->ptr32 - (long)ctx->tcp); 448 ctx->ipv6->hop_limit = 64; 449 } 450 451 swap_array(ctx->eth->h_source, ctx->eth->h_dest); 452 453 if (bpf_skb_change_tail(ctx->skb, (long)ctx->ptr32 - (long)ctx->eth, 0)) 454 goto err; 455 456 return bpf_redirect(ctx->skb->ifindex, 0); 457 err: 458 return TC_ACT_SHOT; 459 } 460 461 static int tcp_validate_cookie(struct tcp_syncookie *ctx) 462 { 463 u32 cookie = bpf_ntohl(ctx->tcp->ack_seq) - 1; 464 u32 seq = bpf_ntohl(ctx->tcp->seq) - 1; 465 u64 first = 0, second; 466 int mssind; 467 u32 hash; 468 469 if (ctx->ipv4) 470 first = (u64)ctx->ipv4->saddr << 32 | ctx->ipv4->daddr; 471 else if (ctx->ipv6) 472 first = (u64)ctx->ipv6->saddr.in6_u.u6_addr8[0] << 32 | 473 ctx->ipv6->daddr.in6_u.u6_addr32[0]; 474 475 second = (u64)seq << 32 | ctx->tcp->source << 16 | ctx->tcp->dest; 476 hash = siphash_2u64(first, second, &test_key_siphash); 477 478 if (ctx->attrs.tstamp_ok) 479 hash -= ctx->attrs.rcv_tsecr & COOKIE_MASK; 480 else 481 hash &= ~COOKIE_MASK; 482 483 hash -= cookie & ~COOKIE_MASK; 484 if (hash) 485 goto err; 486 487 mssind = (cookie & (3 << 6)) >> 6; 488 if (ctx->ipv4) { 489 if (mssind > ARRAY_SIZE(msstab4)) 490 goto err; 491 492 ctx->attrs.mss = msstab4[mssind]; 493 } else { 494 if (mssind > ARRAY_SIZE(msstab6)) 495 goto err; 496 497 ctx->attrs.mss = msstab6[mssind]; 498 } 499 500 ctx->attrs.snd_wscale = cookie & BPF_SYNCOOKIE_WSCALE_MASK; 501 ctx->attrs.rcv_wscale = ctx->attrs.snd_wscale; 502 ctx->attrs.wscale_ok = ctx->attrs.snd_wscale == BPF_SYNCOOKIE_WSCALE_MASK; 503 ctx->attrs.sack_ok = cookie & BPF_SYNCOOKIE_SACK; 504 ctx->attrs.ecn_ok = cookie & BPF_SYNCOOKIE_ECN; 505 506 return 0; 507 err: 508 return -1; 509 } 510 511 static int tcp_handle_ack(struct tcp_syncookie *ctx) 512 { 513 struct bpf_sock_tuple tuple; 514 struct bpf_sock *skc; 515 int ret = TC_ACT_OK; 516 struct sock *sk; 517 u32 tuple_size; 518 519 if (ctx->ipv4) { 520 tuple.ipv4.saddr = ctx->ipv4->saddr; 521 tuple.ipv4.daddr = ctx->ipv4->daddr; 522 tuple.ipv4.sport = ctx->tcp->source; 523 tuple.ipv4.dport = ctx->tcp->dest; 524 tuple_size = sizeof(tuple.ipv4); 525 } else if (ctx->ipv6) { 526 __builtin_memcpy(tuple.ipv6.saddr, &ctx->ipv6->saddr, sizeof(tuple.ipv6.saddr)); 527 __builtin_memcpy(tuple.ipv6.daddr, &ctx->ipv6->daddr, sizeof(tuple.ipv6.daddr)); 528 tuple.ipv6.sport = ctx->tcp->source; 529 tuple.ipv6.dport = ctx->tcp->dest; 530 tuple_size = sizeof(tuple.ipv6); 531 } else { 532 goto out; 533 } 534 535 skc = bpf_skc_lookup_tcp(ctx->skb, &tuple, tuple_size, -1, 0); 536 if (!skc) 537 goto out; 538 539 if (skc->state != TCP_LISTEN) 540 goto release; 541 542 sk = (struct sock *)bpf_skc_to_tcp_sock(skc); 543 if (!sk) 544 goto err; 545 546 if (tcp_validate_header(ctx)) 547 goto err; 548 549 tcp_parse_options(ctx); 550 551 if (tcp_validate_cookie(ctx)) 552 goto err; 553 554 ret = bpf_sk_assign_tcp_reqsk(ctx->skb, sk, &ctx->attrs, sizeof(ctx->attrs)); 555 if (ret < 0) 556 goto err; 557 558 release: 559 bpf_sk_release(skc); 560 out: 561 return ret; 562 563 err: 564 ret = TC_ACT_SHOT; 565 goto release; 566 } 567 568 SEC("tc") 569 int tcp_custom_syncookie(struct __sk_buff *skb) 570 { 571 struct tcp_syncookie ctx = { 572 .skb = skb, 573 }; 574 575 if (tcp_load_headers(&ctx)) 576 return TC_ACT_OK; 577 578 if (ctx.tcp->rst) 579 return TC_ACT_OK; 580 581 if (ctx.tcp->syn) { 582 if (ctx.tcp->ack) 583 return TC_ACT_OK; 584 585 handled_syn = true; 586 587 return tcp_handle_syn(&ctx); 588 } 589 590 handled_ack = true; 591 592 return tcp_handle_ack(&ctx); 593 } 594 595 char _license[] SEC("license") = "GPL"; 596