// SPDX-License-Identifier: GPL-2.0
/* Copyright Amazon.com Inc. or its affiliates. */

#include "vmlinux.h"

#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>
#include "bpf_tracing_net.h"
#include "bpf_kfuncs.h"
#include "test_siphash.h"
#include "test_tcp_custom_syncookie.h"

/* Hash is calculated for each client and split into ISN and TS.
 *
 *       MSB                                    LSB
 * ISN:  | 31 ... 8 | 7 6 |  5  |  4   | 3 2 1 0 |
 *       |  Hash_1  | MSS | ECN | SACK |  WScale |
 *
 * TS:   | 31 ... 8 |           7 ... 0          |
 *       |  Random  |           Hash_2           |
 */
#define COOKIE_BITS 8
#define COOKIE_MASK (((__u32)1 << COOKIE_BITS) - 1)

enum {
	/* 0xf is invalid thus means that SYN did not have WScale. */
	BPF_SYNCOOKIE_WSCALE_MASK = (1 << 4) - 1,
	BPF_SYNCOOKIE_SACK = (1 << 4),
	BPF_SYNCOOKIE_ECN = (1 << 5),
};

#define MSS_LOCAL_IPV4 65495
#define MSS_LOCAL_IPV6 65476

const __u16 msstab4[] = {
	536,
	1300,
	1460,
	MSS_LOCAL_IPV4,
};

const __u16 msstab6[] = {
	1280 - 60, /* IPV6_MIN_MTU - 60 */
	1480 - 60,
	9000 - 60,
	MSS_LOCAL_IPV6,
};

static siphash_key_t test_key_siphash = {
	{ 0x0706050403020100ULL, 0x0f0e0d0c0b0a0908ULL }
};

struct tcp_syncookie {
	struct __sk_buff *skb;
	void *data_end;
	struct ethhdr *eth;
	struct iphdr *ipv4;
	struct ipv6hdr *ipv6;
	struct tcphdr *tcp;
	union {
		char *ptr;
		__be32 *ptr32;
	};
	struct bpf_tcp_req_attrs attrs;
	u32 cookie;
	u64 first;
};

bool handled_syn, handled_ack;

static int tcp_load_headers(struct tcp_syncookie *ctx)
{
	ctx->data_end = (void *)(long)ctx->skb->data_end;
	ctx->eth = (struct ethhdr *)(long)ctx->skb->data;

	if (ctx->eth + 1 > ctx->data_end)
		goto err;

	switch (bpf_ntohs(ctx->eth->h_proto)) {
	case ETH_P_IP:
		ctx->ipv4 = (struct iphdr *)(ctx->eth + 1);

		if (ctx->ipv4 + 1 > ctx->data_end)
			goto err;

		if (ctx->ipv4->ihl != sizeof(*ctx->ipv4) / 4)
			goto err;

		if (ctx->ipv4->version != 4)
			goto err;

		if (ctx->ipv4->protocol != IPPROTO_TCP)
			goto err;

		ctx->tcp = (struct tcphdr *)(ctx->ipv4 + 1);
		break;
	case ETH_P_IPV6:
		ctx->ipv6 = (struct ipv6hdr *)(ctx->eth + 1);

		if (ctx->ipv6 + 1 > ctx->data_end)
			goto err;

		if (ctx->ipv6->version != 6)
			goto err;

		if (ctx->ipv6->nexthdr != NEXTHDR_TCP)
			goto err;

		ctx->tcp = (struct tcphdr *)(ctx->ipv6 + 1);
		break;
	default:
		goto err;
	}

	if (ctx->tcp + 1 > ctx->data_end)
		goto err;

	return 0;
err:
	return -1;
}
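
/* Note on the reload step below: tcp_reload_headers() extends the linear
 * data with bpf_skb_change_tail() so that the maximal 60-byte TCP header
 * always fits between ctx->tcp and ctx->data_end, which gives the verifier
 * a single fixed bound for the option parsing and rewriting that follows.
 * The extra padding is trimmed off again before the SYN+ACK is redirected
 * in tcp_handle_syn().
 */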

static int tcp_reload_headers(struct tcp_syncookie *ctx)
{
	/* Without volatile,
	 * R3 32-bit pointer arithmetic prohibited
	 */
	volatile u64 data_len = ctx->skb->data_end - ctx->skb->data;

	if (ctx->tcp->doff < sizeof(*ctx->tcp) / 4)
		goto err;

	/* Needed to calculate csum and parse TCP options. */
	if (bpf_skb_change_tail(ctx->skb, data_len + 60 - ctx->tcp->doff * 4, 0))
		goto err;

	ctx->data_end = (void *)(long)ctx->skb->data_end;
	ctx->eth = (struct ethhdr *)(long)ctx->skb->data;
	if (ctx->ipv4) {
		ctx->ipv4 = (struct iphdr *)(ctx->eth + 1);
		ctx->ipv6 = NULL;
		ctx->tcp = (struct tcphdr *)(ctx->ipv4 + 1);
	} else {
		ctx->ipv4 = NULL;
		ctx->ipv6 = (struct ipv6hdr *)(ctx->eth + 1);
		ctx->tcp = (struct tcphdr *)(ctx->ipv6 + 1);
	}

	if ((void *)ctx->tcp + 60 > ctx->data_end)
		goto err;

	return 0;
err:
	return -1;
}

static __sum16 tcp_v4_csum(struct tcp_syncookie *ctx, __wsum csum)
{
	return csum_tcpudp_magic(ctx->ipv4->saddr, ctx->ipv4->daddr,
				 ctx->tcp->doff * 4, IPPROTO_TCP, csum);
}

static __sum16 tcp_v6_csum(struct tcp_syncookie *ctx, __wsum csum)
{
	return csum_ipv6_magic(&ctx->ipv6->saddr, &ctx->ipv6->daddr,
			       ctx->tcp->doff * 4, IPPROTO_TCP, csum);
}

static int tcp_validate_header(struct tcp_syncookie *ctx)
{
	s64 csum;

	if (tcp_reload_headers(ctx))
		goto err;

	csum = bpf_csum_diff(0, 0, (void *)ctx->tcp, ctx->tcp->doff * 4, 0);
	if (csum < 0)
		goto err;

	if (ctx->ipv4) {
		/* check tcp_v4_csum(csum) is 0 if not on lo. */

		csum = bpf_csum_diff(0, 0, (void *)ctx->ipv4, ctx->ipv4->ihl * 4, 0);
		if (csum < 0)
			goto err;

		if (csum_fold(csum) != 0)
			goto err;
	} else if (ctx->ipv6) {
		/* check tcp_v6_csum(csum) is 0 if not on lo. */
	}

	return 0;
err:
	return -1;
}

static int tcp_parse_option(__u32 index, struct tcp_syncookie *ctx)
{
	char opcode, opsize;

	if (ctx->ptr + 1 > ctx->data_end)
		goto stop;

	opcode = *ctx->ptr++;

	if (opcode == TCPOPT_EOL)
		goto stop;

	if (opcode == TCPOPT_NOP)
		goto next;

	if (ctx->ptr + 1 > ctx->data_end)
		goto stop;

	opsize = *ctx->ptr++;

	if (opsize < 2)
		goto stop;

	switch (opcode) {
	case TCPOPT_MSS:
		if (opsize == TCPOLEN_MSS && ctx->tcp->syn &&
		    ctx->ptr + (TCPOLEN_MSS - 2) < ctx->data_end)
			ctx->attrs.mss = get_unaligned_be16(ctx->ptr);
		break;
	case TCPOPT_WINDOW:
		if (opsize == TCPOLEN_WINDOW && ctx->tcp->syn &&
		    ctx->ptr + (TCPOLEN_WINDOW - 2) < ctx->data_end) {
			ctx->attrs.wscale_ok = 1;
			ctx->attrs.snd_wscale = *ctx->ptr;
		}
		break;
	case TCPOPT_TIMESTAMP:
		if (opsize == TCPOLEN_TIMESTAMP &&
		    ctx->ptr + (TCPOLEN_TIMESTAMP - 2) < ctx->data_end) {
			ctx->attrs.rcv_tsval = get_unaligned_be32(ctx->ptr);
			ctx->attrs.rcv_tsecr = get_unaligned_be32(ctx->ptr + 4);

			if (ctx->tcp->syn && ctx->attrs.rcv_tsecr)
				ctx->attrs.tstamp_ok = 0;
			else
				ctx->attrs.tstamp_ok = 1;
		}
		break;
	case TCPOPT_SACK_PERM:
		if (opsize == TCPOLEN_SACK_PERM && ctx->tcp->syn &&
		    ctx->ptr + (TCPOLEN_SACK_PERM - 2) < ctx->data_end)
			ctx->attrs.sack_ok = 1;
		break;
	}

	ctx->ptr += opsize - 2;
next:
	return 0;
stop:
	return 1;
}

static void tcp_parse_options(struct tcp_syncookie *ctx)
{
	ctx->ptr = (char *)(ctx->tcp + 1);

	bpf_loop(40, tcp_parse_option, ctx, 0);
}

static int tcp_validate_sysctl(struct tcp_syncookie *ctx)
{
	if ((ctx->ipv4 && ctx->attrs.mss != MSS_LOCAL_IPV4) ||
	    (ctx->ipv6 && ctx->attrs.mss != MSS_LOCAL_IPV6))
		goto err;

	if (!ctx->attrs.wscale_ok || ctx->attrs.snd_wscale != 7)
		goto err;

	if (!ctx->attrs.tstamp_ok)
		goto err;

	if (!ctx->attrs.sack_ok)
		goto err;

	if (!ctx->tcp->ece || !ctx->tcp->cwr)
		goto err;

	return 0;
err:
	return -1;
}
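
/* Example of the low COOKIE_BITS byte that tcp_prepare_cookie() builds
 * (layout documented at the top of this file): a SYN that passes
 * tcp_validate_sysctl() over IPv4 uses mssind 3 (MSS_LOCAL_IPV4),
 * wscale 7, SACK and ECN, so the low byte of the ISN becomes
 *
 *	(3 << 6) | BPF_SYNCOOKIE_ECN | BPF_SYNCOOKIE_SACK | 7 == 0xf7
 *
 * while the low byte of the hash itself (Hash_2) is carried in the low
 * byte of the TS value echoed back by the client.
 */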

static void tcp_prepare_cookie(struct tcp_syncookie *ctx)
{
	u32 seq = bpf_ntohl(ctx->tcp->seq);
	u64 first = 0, second;
	int mssind = 0;
	u32 hash;

	if (ctx->ipv4) {
		for (mssind = ARRAY_SIZE(msstab4) - 1; mssind; mssind--)
			if (ctx->attrs.mss >= msstab4[mssind])
				break;

		ctx->attrs.mss = msstab4[mssind];

		first = (u64)ctx->ipv4->saddr << 32 | ctx->ipv4->daddr;
	} else if (ctx->ipv6) {
		for (mssind = ARRAY_SIZE(msstab6) - 1; mssind; mssind--)
			if (ctx->attrs.mss >= msstab6[mssind])
				break;

		ctx->attrs.mss = msstab6[mssind];

		first = (u64)ctx->ipv6->saddr.in6_u.u6_addr8[0] << 32 |
			ctx->ipv6->daddr.in6_u.u6_addr32[0];
	}

	second = (u64)seq << 32 | ctx->tcp->source << 16 | ctx->tcp->dest;
	hash = siphash_2u64(first, second, &test_key_siphash);

	if (ctx->attrs.tstamp_ok) {
		ctx->attrs.rcv_tsecr = bpf_get_prandom_u32();
		ctx->attrs.rcv_tsecr &= ~COOKIE_MASK;
		ctx->attrs.rcv_tsecr |= hash & COOKIE_MASK;
	}

	hash &= ~COOKIE_MASK;
	hash |= mssind << 6;

	if (ctx->attrs.wscale_ok)
		hash |= ctx->attrs.snd_wscale & BPF_SYNCOOKIE_WSCALE_MASK;

	if (ctx->attrs.sack_ok)
		hash |= BPF_SYNCOOKIE_SACK;

	if (ctx->attrs.tstamp_ok && ctx->tcp->ece && ctx->tcp->cwr)
		hash |= BPF_SYNCOOKIE_ECN;

	ctx->cookie = hash;
}

static void tcp_write_options(struct tcp_syncookie *ctx)
{
	ctx->ptr32 = (__be32 *)(ctx->tcp + 1);

	*ctx->ptr32++ = bpf_htonl(TCPOPT_MSS << 24 | TCPOLEN_MSS << 16 |
				  ctx->attrs.mss);

	if (ctx->attrs.wscale_ok)
		*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
					  TCPOPT_WINDOW << 16 |
					  TCPOLEN_WINDOW << 8 |
					  ctx->attrs.snd_wscale);

	if (ctx->attrs.tstamp_ok) {
		if (ctx->attrs.sack_ok)
			*ctx->ptr32++ = bpf_htonl(TCPOPT_SACK_PERM << 24 |
						  TCPOLEN_SACK_PERM << 16 |
						  TCPOPT_TIMESTAMP << 8 |
						  TCPOLEN_TIMESTAMP);
		else
			*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
						  TCPOPT_NOP << 16 |
						  TCPOPT_TIMESTAMP << 8 |
						  TCPOLEN_TIMESTAMP);

		*ctx->ptr32++ = bpf_htonl(ctx->attrs.rcv_tsecr);
		*ctx->ptr32++ = bpf_htonl(ctx->attrs.rcv_tsval);
	} else if (ctx->attrs.sack_ok) {
		*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
					  TCPOPT_NOP << 16 |
					  TCPOPT_SACK_PERM << 8 |
					  TCPOLEN_SACK_PERM);
	}
}

static int tcp_handle_syn(struct tcp_syncookie *ctx)
{
	s64 csum;

	if (tcp_validate_header(ctx))
		goto err;

	tcp_parse_options(ctx);

	if (tcp_validate_sysctl(ctx))
		goto err;

	tcp_prepare_cookie(ctx);
	tcp_write_options(ctx);

	swap(ctx->tcp->source, ctx->tcp->dest);
	ctx->tcp->check = 0;
	ctx->tcp->ack_seq = bpf_htonl(bpf_ntohl(ctx->tcp->seq) + 1);
	ctx->tcp->seq = bpf_htonl(ctx->cookie);
	ctx->tcp->doff = ((long)ctx->ptr32 - (long)ctx->tcp) >> 2;
	ctx->tcp->ack = 1;
	if (!ctx->attrs.tstamp_ok || !ctx->tcp->ece || !ctx->tcp->cwr)
		ctx->tcp->ece = 0;
	ctx->tcp->cwr = 0;

	csum = bpf_csum_diff(0, 0, (void *)ctx->tcp, ctx->tcp->doff * 4, 0);
	if (csum < 0)
		goto err;

	if (ctx->ipv4) {
		swap(ctx->ipv4->saddr, ctx->ipv4->daddr);
		ctx->tcp->check = tcp_v4_csum(ctx, csum);

		ctx->ipv4->check = 0;
		ctx->ipv4->tos = 0;
		ctx->ipv4->tot_len = bpf_htons((long)ctx->ptr32 - (long)ctx->ipv4);
		ctx->ipv4->id = 0;
		ctx->ipv4->ttl = 64;

		csum = bpf_csum_diff(0, 0, (void *)ctx->ipv4, sizeof(*ctx->ipv4), 0);
		if (csum < 0)
			goto err;

		ctx->ipv4->check = csum_fold(csum);
	} else if (ctx->ipv6) {
		swap(ctx->ipv6->saddr, ctx->ipv6->daddr);
		ctx->tcp->check = tcp_v6_csum(ctx, csum);

		*(__be32 *)ctx->ipv6 = bpf_htonl(0x60000000);
		ctx->ipv6->payload_len = bpf_htons((long)ctx->ptr32 - (long)ctx->tcp);
		ctx->ipv6->hop_limit = 64;
	}

	swap_array(ctx->eth->h_source, ctx->eth->h_dest);

	if (bpf_skb_change_tail(ctx->skb, (long)ctx->ptr32 - (long)ctx->eth, 0))
		goto err;

	return bpf_redirect(ctx->skb->ifindex, 0);
err:
	return TC_ACT_SHOT;
}
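
/* ACK-path counterpart of tcp_prepare_cookie(): tcp_validate_cookie()
 * recomputes the same siphash over the 4-tuple and the client's original
 * ISN (seq - 1), then subtracts Hash_2 (the low byte echoed in TSecr) and
 * Hash_1 (the upper 24 bits of ack_seq - 1).  A cookie generated with the
 * current key leaves exactly zero; anything else is rejected.
 */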

static int tcp_validate_cookie(struct tcp_syncookie *ctx)
{
	u32 cookie = bpf_ntohl(ctx->tcp->ack_seq) - 1;
	u32 seq = bpf_ntohl(ctx->tcp->seq) - 1;
	u64 first = 0, second;
	int mssind;
	u32 hash;

	if (ctx->ipv4)
		first = (u64)ctx->ipv4->saddr << 32 | ctx->ipv4->daddr;
	else if (ctx->ipv6)
		first = (u64)ctx->ipv6->saddr.in6_u.u6_addr8[0] << 32 |
			ctx->ipv6->daddr.in6_u.u6_addr32[0];

	second = (u64)seq << 32 | ctx->tcp->source << 16 | ctx->tcp->dest;
	hash = siphash_2u64(first, second, &test_key_siphash);

	if (ctx->attrs.tstamp_ok)
		hash -= ctx->attrs.rcv_tsecr & COOKIE_MASK;
	else
		hash &= ~COOKIE_MASK;

	hash -= cookie & ~COOKIE_MASK;
	if (hash)
		goto err;

	mssind = (cookie & (3 << 6)) >> 6;
	if (ctx->ipv4) {
		if (mssind >= ARRAY_SIZE(msstab4))
			goto err;

		ctx->attrs.mss = msstab4[mssind];
	} else {
		if (mssind >= ARRAY_SIZE(msstab6))
			goto err;

		ctx->attrs.mss = msstab6[mssind];
	}

	ctx->attrs.snd_wscale = cookie & BPF_SYNCOOKIE_WSCALE_MASK;
	ctx->attrs.rcv_wscale = ctx->attrs.snd_wscale;
	ctx->attrs.wscale_ok = ctx->attrs.snd_wscale != BPF_SYNCOOKIE_WSCALE_MASK;
	ctx->attrs.sack_ok = cookie & BPF_SYNCOOKIE_SACK;
	ctx->attrs.ecn_ok = cookie & BPF_SYNCOOKIE_ECN;

	return 0;
err:
	return -1;
}

static int tcp_handle_ack(struct tcp_syncookie *ctx)
{
	struct bpf_sock_tuple tuple;
	struct bpf_sock *skc;
	int ret = TC_ACT_OK;
	struct sock *sk;
	u32 tuple_size;

	if (ctx->ipv4) {
		tuple.ipv4.saddr = ctx->ipv4->saddr;
		tuple.ipv4.daddr = ctx->ipv4->daddr;
		tuple.ipv4.sport = ctx->tcp->source;
		tuple.ipv4.dport = ctx->tcp->dest;
		tuple_size = sizeof(tuple.ipv4);
	} else if (ctx->ipv6) {
		__builtin_memcpy(tuple.ipv6.saddr, &ctx->ipv6->saddr, sizeof(tuple.ipv6.saddr));
		__builtin_memcpy(tuple.ipv6.daddr, &ctx->ipv6->daddr, sizeof(tuple.ipv6.daddr));
		tuple.ipv6.sport = ctx->tcp->source;
		tuple.ipv6.dport = ctx->tcp->dest;
		tuple_size = sizeof(tuple.ipv6);
	} else {
		goto out;
	}

	skc = bpf_skc_lookup_tcp(ctx->skb, &tuple, tuple_size, -1, 0);
	if (!skc)
		goto out;

	if (skc->state != TCP_LISTEN)
		goto release;

	sk = (struct sock *)bpf_skc_to_tcp_sock(skc);
	if (!sk)
		goto err;

	if (tcp_validate_header(ctx))
		goto err;

	tcp_parse_options(ctx);

	if (tcp_validate_cookie(ctx))
		goto err;

	ret = bpf_sk_assign_tcp_reqsk(ctx->skb, sk, &ctx->attrs, sizeof(ctx->attrs));
	if (ret < 0)
		goto err;

release:
	bpf_sk_release(skc);
out:
	return ret;

err:
	ret = TC_ACT_SHOT;
	goto release;
}
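
/* The TC entry point below covers both halves of the handshake: a pure SYN
 * is answered in place with a cooked SYN+ACK and bounced back out of the
 * ingress interface via bpf_redirect(), while non-SYN, non-RST segments are
 * handed to tcp_handle_ack(), which installs a request socket with
 * bpf_sk_assign_tcp_reqsk() when a TCP_LISTEN socket matches and the cookie
 * validates, drops the segment if validation fails, and otherwise lets it
 * continue up the stack.
 */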

SEC("tc")
int tcp_custom_syncookie(struct __sk_buff *skb)
{
	struct tcp_syncookie ctx = {
		.skb = skb,
	};

	if (tcp_load_headers(&ctx))
		return TC_ACT_OK;

	if (ctx.tcp->rst)
		return TC_ACT_OK;

	if (ctx.tcp->syn) {
		if (ctx.tcp->ack)
			return TC_ACT_OK;

		handled_syn = true;

		return tcp_handle_syn(&ctx);
	}

	handled_ack = true;

	return tcp_handle_ack(&ctx);
}

char _license[] SEC("license") = "GPL";